From a0114f45bdc450e43e22cc48a14e53a1f00ed3eb Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Mon, 13 Nov 2023 10:34:28 -0800 Subject: [PATCH] Revert "Integration with REPL Spark job (#2327) (#2338)" This reverts commit 58a5ae5e765fb09a3741e32cc746da23d8f7a6df. --- .../org/opensearch/sql/plugin/SQLPlugin.java | 5 +- spark/build.gradle | 1 - .../model/SparkSubmitParameters.java | 14 +- .../spark/data/constants/SparkConstants.java | 4 - .../dispatcher/SparkQueryDispatcher.java | 67 ++-- .../session/CreateSessionRequest.java | 21 +- .../execution/session/InteractiveSession.java | 15 +- .../spark/execution/session/SessionId.java | 21 +- .../execution/session/SessionManager.java | 5 +- .../spark/execution/session/SessionState.java | 7 +- .../spark/execution/session/SessionType.java | 14 +- .../spark/execution/statement/Statement.java | 20 +- .../execution/statement/StatementModel.java | 10 - .../execution/statement/StatementState.java | 7 +- .../execution/statestore/StateStore.java | 203 +++------- .../response/JobExecutionResponseReader.java | 4 - .../query_execution_request_mapping.yml | 40 -- .../query_execution_request_settings.yml | 11 - ...AsyncQueryExecutorServiceImplSpecTest.java | 374 ------------------ .../dispatcher/SparkQueryDispatcherTest.java | 6 +- .../session/InteractiveSessionTest.java | 55 +-- .../execution/statement/StatementTest.java | 63 ++- 22 files changed, 157 insertions(+), 810 deletions(-) delete mode 100644 spark/src/main/resources/query_execution_request_mapping.yml delete mode 100644 spark/src/main/resources/query_execution_request_settings.yml delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f714a8366b..eb6eabf988 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.emrserverless.AWSEMRServerless; @@ -320,7 +321,9 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( new FlintIndexMetadataReaderImpl(client), client, new SessionManager( - new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); + new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), + emrServerlessClient, + pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/build.gradle b/spark/build.gradle index 8f4388495e..15f1e200e0 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -68,7 +68,6 @@ dependencies { because 'allows tests to run from IDEs that bundle older version of launcher' } testImplementation("org.opensearch.test:framework:${opensearch_version}") - testImplementation project(':opensearch') } test { diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index db78abb2a8..0609d8903c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -12,7 +12,6 @@ import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; import static org.opensearch.sql.spark.data.constants.SparkConstants.*; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.net.URI; import java.net.URISyntaxException; @@ -40,7 +39,7 @@ public class SparkSubmitParameters { public static class Builder { - private String className; + private final String className; private final Map config; private String extraParameters; @@ -71,11 +70,6 @@ public static Builder builder() { return new Builder(); } - public Builder className(String className) { - this.className = className; - return this; - } - public Builder dataSource(DataSourceMetadata metadata) { if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); @@ -147,12 +141,6 @@ public Builder extraParameters(String params) { return this; } - public Builder sessionExecution(String sessionId, String datasourceName) { - config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - config.put(FLINT_JOB_SESSION_ID, sessionId); - return this; - } - public SparkSubmitParameters build() { return new SparkSubmitParameters(className, config, extraParameters); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 85ce3c4989..1b248eb15d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -87,8 +87,4 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; - - public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; - public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; - public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 2bd1ae67b9..8d5ae10e91 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; @@ -97,19 +96,12 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result(); } - JSONObject result; - if (asyncQueryJobMetadata.getSessionId() == null) { - // either empty json when the result is not available or data with status - // Fetch from Result Index - result = - jobExecutionResponseReader.getResultFromOpensearchIndex( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); - } else { - // when session enabled, jobId in asyncQueryJobMetadata is actually queryId. - result = - jobExecutionResponseReader.getResultWithQueryId( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); - } + // either empty json when the result is not available or data with status + // Fetch from Result Index + JSONObject result = + jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + // if result index document has a status, we are gonna use the status directly; otherwise, we // will use emr-s job status. // That a job is successful does not mean there is no error in execution. For example, even if @@ -238,7 +230,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + false, + dataSourceMetadata.getResultIndex()); if (sessionManager.isEnabled()) { Session session; if (dispatchQueryRequest.getSessionId() != null) { @@ -253,19 +260,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ // create session if not exist session = sessionManager.createSession( - new CreateSessionRequest( - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .className(FLINT_SESSION_CLASS_NAME) - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), - tags, - dataSourceMetadata.getResultIndex(), - dataSourceMetadata.getName())); + new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); } StatementId statementId = session.submit( @@ -277,22 +272,6 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); } else { - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - false, - dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index ca2b2b4867..17e3346248 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -5,30 +5,11 @@ package org.opensearch.sql.spark.execution.session; -import java.util.Map; import lombok.Data; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.StartJobRequest; @Data public class CreateSessionRequest { - private final String jobName; - private final String applicationId; - private final String executionRoleArn; - private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; - private final Map tags; - private final String resultIndex; + private final StartJobRequest startJobRequest; private final String datasourceName; - - public StartJobRequest getStartJobRequest() { - return new StartJobRequest( - "select 1", - jobName, - applicationId, - executionRoleArn, - sparkSubmitParametersBuilder.build().toString(), - tags, - false, - resultIndex); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 4428c3b83d..e33ef4245a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -42,17 +42,13 @@ public class InteractiveSession implements Session { @Override public void open(CreateSessionRequest createSessionRequest) { try { - // append session id; - createSessionRequest - .getSparkSubmitParametersBuilder() - .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel); + createSession(stateStore).apply(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -63,8 +59,7 @@ public void open(CreateSessionRequest createSessionRequest) { /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional model = - getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); + Optional model = getSession(stateStore).apply(sessionModel.getId()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -74,8 +69,7 @@ public void close() { /** Submit statement. If submit successfully, Statement in waiting state. */ public StatementId submit(QueryRequest request) { - Optional model = - getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); + Optional model = getSession(stateStore).apply(sessionModel.getId()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -90,7 +84,6 @@ public StatementId submit(QueryRequest request) { .stateStore(stateStore) .statementId(statementId) .langType(LangType.SQL) - .datasourceName(sessionModel.getDatasourceName()) .query(request.getQuery()) .queryId(statementId.getId()) .build(); @@ -110,7 +103,7 @@ public StatementId submit(QueryRequest request) { @Override public Optional get(StatementId stID) { - return StateStore.getStatement(stateStore, sessionModel.getDatasourceName()) + return StateStore.getStatement(stateStore) .apply(stID.getId()) .map( model -> diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index b3bd716925..861d906b9b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -5,32 +5,15 @@ package org.opensearch.sql.spark.execution.session; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import lombok.Data; import org.apache.commons.lang3.RandomStringUtils; @Data public class SessionId { - public static final int PREFIX_LEN = 10; - private final String sessionId; - public static SessionId newSessionId(String datasourceName) { - return new SessionId(encode(datasourceName)); - } - - public String getDataSourceName() { - return decode(sessionId); - } - - private static String decode(String sessionId) { - return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN); - } - - private static String encode(String datasourceName) { - String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; - return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); + public static SessionId newSessionId() { + return new SessionId(RandomStringUtils.randomAlphanumeric(16)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index c0f7bbcde8..c34be7015f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -28,7 +28,7 @@ public class SessionManager { public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() - .sessionId(newSessionId(request.getDatasourceName())) + .sessionId(newSessionId()) .stateStore(stateStore) .serverlessClient(emrServerlessClient) .build(); @@ -37,8 +37,7 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = - StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId()); + Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index bd5d14c603..a4da957f12 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -8,7 +8,6 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -33,10 +32,8 @@ public enum SessionState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static SessionState fromString(String key) { - for (SessionState ss : SessionState.values()) { - if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) { - return ss; - } + if (STATES.containsKey(key)) { + return STATES.get(key); } throw new IllegalArgumentException("Invalid session state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java index 10b9ce7bd5..dd179a1dc5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -5,7 +5,9 @@ package org.opensearch.sql.spark.execution.session; -import java.util.Locale; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; import lombok.Getter; @Getter @@ -18,11 +20,13 @@ public enum SessionType { this.sessionType = sessionType; } + private static Map TYPES = + Arrays.stream(SessionType.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + public static SessionType fromString(String key) { - for (SessionType sType : SessionType.values()) { - if (sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) { - return sType; - } + if (TYPES.containsKey(key)) { + return TYPES.get(key); } throw new IllegalArgumentException("Invalid session type: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index d84c91bdb8..8fcedb5fca 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -32,7 +32,6 @@ public class Statement { private final String jobId; private final StatementId statementId; private final LangType langType; - private final String datasourceName; private final String query; private final String queryId; private final StateStore stateStore; @@ -43,16 +42,8 @@ public class Statement { public void open() { try { statementModel = - submitStatement( - sessionId, - applicationId, - jobId, - statementId, - langType, - datasourceName, - query, - queryId); - statementModel = createStatement(stateStore, datasourceName).apply(statementModel); + submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); + statementModel = createStatement(stateStore).apply(statementModel); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; LOG.error(errorMsg); @@ -70,8 +61,7 @@ public void cancel() { } try { this.statementModel = - updateStatementState(stateStore, statementModel.getDatasourceName()) - .apply(this.statementModel, StatementState.CANCELLED); + updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); } catch (DocumentMissingException e) { String errorMsg = String.format("cancel statement failed. no statement found. statement: %s.", statementId); @@ -79,9 +69,7 @@ public void cancel() { throw new IllegalStateException(errorMsg); } catch (VersionConflictEngineException e) { this.statementModel = - getStatement(stateStore, statementModel.getDatasourceName()) - .apply(statementModel.getId()) - .orElse(this.statementModel); + getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); String errorMsg = String.format( "cancel statement failed. current statementState: %s " + "statement: %s.", diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index 2a1043bf73..c7f681c541 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -6,7 +6,6 @@ package org.opensearch.sql.spark.execution.statement; import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; -import static org.opensearch.sql.spark.execution.session.SessionModel.DATASOURCE_NAME; import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; @@ -46,7 +45,6 @@ public class StatementModel extends StateModel { private final String applicationId; private final String jobId; private final LangType langType; - private final String datasourceName; private final String query; private final String queryId; private final long submitTime; @@ -67,7 +65,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(APPLICATION_ID, applicationId) .field(JOB_ID, jobId) .field(LANG, langType.getText()) - .field(DATASOURCE_NAME, datasourceName) .field(QUERY, query) .field(QUERY_ID, queryId) .field(SUBMIT_TIME, submitTime) @@ -85,7 +82,6 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT .applicationId(copy.applicationId) .jobId(copy.jobId) .langType(copy.langType) - .datasourceName(copy.datasourceName) .query(copy.query) .queryId(copy.queryId) .submitTime(copy.submitTime) @@ -105,7 +101,6 @@ public static StatementModel copyWithState( .applicationId(copy.applicationId) .jobId(copy.jobId) .langType(copy.langType) - .datasourceName(copy.datasourceName) .query(copy.query) .queryId(copy.queryId) .submitTime(copy.submitTime) @@ -148,9 +143,6 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon case LANG: builder.langType(LangType.fromString(parser.text())); break; - case DATASOURCE_NAME: - builder.datasourceName(parser.text()); - break; case QUERY: builder.query(parser.text()); break; @@ -176,7 +168,6 @@ public static StatementModel submitStatement( String jobId, StatementId statementId, LangType langType, - String datasourceName, String query, String queryId) { return builder() @@ -187,7 +178,6 @@ public static StatementModel submitStatement( .applicationId(applicationId) .jobId(jobId) .langType(langType) - .datasourceName(datasourceName) .query(query) .queryId(queryId) .submitTime(System.currentTimeMillis()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java index 48978ff8f9..33f7f5e831 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -6,7 +6,6 @@ package org.opensearch.sql.spark.execution.statement; import java.util.Arrays; -import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -31,10 +30,8 @@ public enum StatementState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static StatementState fromString(String key) { - for (StatementState ss : StatementState.values()) { - if (ss.getState().toLowerCase(Locale.ROOT).equals(key)) { - return ss; - } + if (STATES.containsKey(key)) { + return STATES.get(key); } throw new IllegalArgumentException("Invalid statement state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index a36ee3ef45..bd72b17353 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -5,22 +5,15 @@ package org.opensearch.sql.spark.execution.statestore; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; - import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; import java.util.Locale; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; import lombok.RequiredArgsConstructor; -import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; @@ -29,9 +22,6 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.action.ActionFuture; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -43,29 +33,15 @@ import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; -/** - * State Store maintain the state of Session and Statement. State State create/update/get doc on - * index regardless user FGAC permissions. - */ @RequiredArgsConstructor public class StateStore { - public static String SETTINGS_FILE_NAME = "query_execution_request_settings.yml"; - public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml"; - public static Function DATASOURCE_TO_REQUEST_INDEX = - datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); - public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME); - private static final Logger LOG = LogManager.getLogger(); + private final String indexName; private final Client client; - private final ClusterService clusterService; - protected T create( - T st, StateModel.CopyBuilder builder, String indexName) { + protected T create(T st, StateModel.CopyBuilder builder) { try { - if (!this.clusterService.state().routingTable().hasIndex(indexName)) { - createIndex(indexName); - } IndexRequest indexRequest = new IndexRequest(indexName) .id(st.getId()) @@ -74,60 +50,48 @@ protected T create( .setIfPrimaryTerm(st.getPrimaryTerm()) .create(true) .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - ; - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", st.getId()); - return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. id: %s, error: %s", - st.getId(), - indexResponse.getResult().getLowercase())); - } + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", st.getId()); + return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. id: %s, error: %s", + st.getId(), + indexResponse.getResult().getLowercase())); } } catch (IOException e) { throw new RuntimeException(e); } } - protected Optional get( - String sid, StateModel.FromXContent builder, String indexName) { + protected Optional get(String sid, StateModel.FromXContent builder) { try { - if (!this.clusterService.state().routingTable().hasIndex(indexName)) { - createIndex(indexName); + GetRequest getRequest = new GetRequest().index(indexName).id(sid); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { return Optional.empty(); } - GetRequest getRequest = new GetRequest().index(indexName).id(sid).refresh(true); - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { - return Optional.empty(); - } - } } catch (IOException e) { throw new RuntimeException(e); } } protected T updateState( - T st, S state, StateModel.StateCopyBuilder builder, String indexName) { + T st, S state, StateModel.StateCopyBuilder builder) { try { T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); UpdateRequest updateRequest = @@ -139,110 +103,47 @@ protected T updateState( .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) .fetchSource(true) .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - UpdateResponse updateResponse = client.update(updateRequest).actionGet(); - if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { - LOG.debug("Successfully update doc. id: {}", st.getId()); - return builder.of( - model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed update doc. id: %s, error: %s", - st.getId(), - updateResponse.getResult().getLowercase())); - } + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { + LOG.debug("Successfully update doc. id: {}", st.getId()); + return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed update doc. id: %s, error: %s", + st.getId(), + updateResponse.getResult().getLowercase())); } } catch (IOException e) { throw new RuntimeException(e); } } - private void createIndex(String indexName) { - try { - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest - .mapping(loadConfigFromResource(MAPPING_FILE_NAME), XContentType.YAML) - .settings(loadConfigFromResource(SETTINGS_FILE_NAME), XContentType.YAML); - ActionFuture createIndexResponseActionFuture; - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); - } - CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); - if (createIndexResponse.isAcknowledged()) { - LOG.info("Index: {} creation Acknowledged", indexName); - } else { - throw new RuntimeException("Index creation is not acknowledged."); - } - } catch (Throwable e) { - throw new RuntimeException( - "Internal server error while creating" + indexName + " index:: " + e.getMessage()); - } - } - - private String loadConfigFromResource(String fileName) throws IOException { - InputStream fileStream = StateStore.class.getClassLoader().getResourceAsStream(fileName); - return IOUtils.toString(fileStream, StandardCharsets.UTF_8); - } - /** Helper Functions */ - public static Function createStatement( - StateStore stateStore, String datasourceName) { - return (st) -> - stateStore.create( - st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + public static Function createStatement(StateStore stateStore) { + return (st) -> stateStore.create(st, StatementModel::copy); } - public static Function> getStatement( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + public static Function> getStatement(StateStore stateStore) { + return (docId) -> stateStore.get(docId, StatementModel::fromXContent); } public static BiFunction updateStatementState( - StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - StatementModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function createSession( - StateStore stateStore, String datasourceName) { - return (session) -> - stateStore.create( - session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + StateStore stateStore) { + return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState); } - public static Function> getSession( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + public static Function createSession(StateStore stateStore) { + return (session) -> stateStore.create(session, SessionModel::of); } - public static Function> searchSession(StateStore stateStore) { - return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX); + public static Function> getSession(StateStore stateStore) { + return (docId) -> stateStore.get(docId, SessionModel::fromXContent); } public static BiFunction updateSessionState( - StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - SessionModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) { - String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); - return () -> stateStore.createIndex(indexName); + StateStore stateStore) { + return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java index 2614992463..d3cbd68dce 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -39,10 +39,6 @@ public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex); } - public JSONObject getResultWithQueryId(String queryId, String resultIndex) { - return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultIndex); - } - private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) { SearchRequest searchRequest = new SearchRequest(); String searchResultIndex = resultIndex == null ? SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex; diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml deleted file mode 100644 index 87bd927e6e..0000000000 --- a/spark/src/main/resources/query_execution_request_mapping.yml +++ /dev/null @@ -1,40 +0,0 @@ ---- -## -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -## - -# Schema file for the .ql-job-metadata index -# Also "dynamic" is set to "false" so that other fields can be added. -dynamic: false -properties: - type: - type: keyword - state: - type: keyword - statementId: - type: keyword - applicationId: - type: keyword - sessionId: - type: keyword - sessionType: - type: keyword - error: - type: text - lang: - type: keyword - query: - type: text - dataSourceName: - type: keyword - submitTime: - type: date - format: strict_date_time||epoch_millis - jobId: - type: keyword - lastUpdateTime: - type: date - format: strict_date_time||epoch_millis - queryId: - type: keyword diff --git a/spark/src/main/resources/query_execution_request_settings.yml b/spark/src/main/resources/query_execution_request_settings.yml deleted file mode 100644 index da2bf07bf1..0000000000 --- a/spark/src/main/resources/query_execution_request_settings.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -## -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -## - -# Settings file for the .ql-job-metadata index -index: - number_of_shards: "1" - auto_expand_replicas: "0-2" - number_of_replicas: "0" diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java deleted file mode 100644 index 3eb8958eb2..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ /dev/null @@ -1,374 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.asyncquery; - -import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_ENABLED_SETTING; -import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; -import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE; -import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; -import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; - -import com.amazonaws.services.emrserverless.model.CancelJobRunResult; -import com.amazonaws.services.emrserverless.model.GetJobRunResult; -import com.amazonaws.services.emrserverless.model.JobRun; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import lombok.Getter; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.plugins.Plugin; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; -import org.opensearch.sql.datasources.encryptor.EncryptorImpl; -import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; -import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; -import org.opensearch.sql.datasources.service.DataSourceServiceImpl; -import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; -import org.opensearch.sql.opensearch.setting.OpenSearchSettings; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; -import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; -import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statement.StatementModel; -import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; -import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; -import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; -import org.opensearch.sql.spark.rest.model.LangType; -import org.opensearch.sql.storage.DataSourceFactory; -import org.opensearch.test.OpenSearchIntegTestCase; - -public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase { - public static final String DATASOURCE = "mys3"; - - private ClusterService clusterService; - private org.opensearch.sql.common.setting.Settings pluginSettings; - private NodeClient client; - private DataSourceServiceImpl dataSourceService; - private StateStore stateStore; - private ClusterSettings clusterSettings; - - @Override - protected Collection> nodePlugins() { - return Arrays.asList(TestSettingPlugin.class); - } - - public static class TestSettingPlugin extends Plugin { - @Override - public List> getSettings() { - return OpenSearchSettings.pluginSettings(); - } - } - - @Before - public void setup() { - clusterService = clusterService(); - clusterSettings = clusterService.getClusterSettings(); - pluginSettings = new OpenSearchSettings(clusterSettings); - client = (NodeClient) cluster().client(); - dataSourceService = createDataSourceService(); - dataSourceService.createDataSource( - new DataSourceMetadata( - DATASOURCE, - DataSourceType.S3GLUE, - ImmutableList.of(), - ImmutableMap.of( - "glue.auth.type", - "iam_role", - "glue.auth.role_arn", - "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", - "glue.indexstore.opensearch.uri", - "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200", - "glue.indexstore.opensearch.auth", - "noauth"), - null)); - stateStore = new StateStore(client, clusterService); - createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME); - } - - @After - public void clean() { - client - .admin() - .cluster() - .prepareUpdateSettings() - .setTransientSettings( - Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()) - .get(); - } - - @Test - public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { - LocalEMRSClient emrsClient = new LocalEMRSClient(); - AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); - - // disable session - enableSession(false); - - // 1. create async query. - CreateAsyncQueryResponse response = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); - assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); - emrsClient.startJobRunCalled(1); - - // 2. fetch async query result. - AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); - assertEquals("RUNNING", asyncQueryResults.getStatus()); - emrsClient.getJobRunResultCalled(1); - - // 3. cancel async query. - String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); - assertEquals(response.getQueryId(), cancelQueryId); - emrsClient.cancelJobRunCalled(1); - } - - @Test - public void createAsyncQueryCreateJobWithCorrectParameters() { - LocalEMRSClient emrsClient = new LocalEMRSClient(); - AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); - - enableSession(false); - CreateAsyncQueryResponse response = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); - String params = emrsClient.getJobRequest().getSparkSubmitParams(); - assertNull(response.getSessionId()); - assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); - assertFalse( - params.contains( - String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); - assertFalse( - params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); - - // enable session - enableSession(true); - response = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); - params = emrsClient.getJobRequest().getSparkSubmitParams(); - assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); - assertTrue( - params.contains( - String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); - assertTrue( - params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); - } - - @Test - public void withSessionCreateAsyncQueryThenGetResultThenCancel() { - LocalEMRSClient emrsClient = new LocalEMRSClient(); - AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); - - // enable session - enableSession(true); - - // 1. create async query. - CreateAsyncQueryResponse response = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); - assertNotNull(response.getSessionId()); - Optional statementModel = - getStatement(stateStore, DATASOURCE).apply(response.getQueryId()); - assertTrue(statementModel.isPresent()); - assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); - - // 2. fetch async query result. - AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); - assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); - - // 3. cancel async query. - String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); - assertEquals(response.getQueryId(), cancelQueryId); - } - - @Test - public void reuseSessionWhenCreateAsyncQuery() { - LocalEMRSClient emrsClient = new LocalEMRSClient(); - AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); - - // enable session - enableSession(true); - - // 1. create async query. - CreateAsyncQueryResponse first = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); - assertNotNull(first.getSessionId()); - - // 2. reuse session id - CreateAsyncQueryResponse second = - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest( - "select 1", DATASOURCE, LangType.SQL, first.getSessionId())); - - assertEquals(first.getSessionId(), second.getSessionId()); - assertNotEquals(first.getQueryId(), second.getQueryId()); - // one session doc. - assertEquals( - 1, - search( - QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE)) - .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); - // two statement docs has same sessionId. - assertEquals( - 2, - search( - QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) - .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); - - Optional firstModel = - getStatement(stateStore, DATASOURCE).apply(first.getQueryId()); - assertTrue(firstModel.isPresent()); - assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); - assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); - assertEquals(first.getQueryId(), firstModel.get().getQueryId()); - Optional secondModel = - getStatement(stateStore, DATASOURCE).apply(second.getQueryId()); - assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); - assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); - assertEquals(second.getQueryId(), secondModel.get().getQueryId()); - } - - private DataSourceServiceImpl createDataSourceService() { - String masterKey = "1234567890"; - DataSourceMetadataStorage dataSourceMetadataStorage = - new OpenSearchDataSourceMetadataStorage( - client, clusterService, new EncryptorImpl(masterKey)); - return new DataSourceServiceImpl( - new ImmutableSet.Builder() - .add(new GlueDataSourceFactory(pluginSettings)) - .build(), - dataSourceMetadataStorage, - meta -> {}); - } - - private AsyncQueryExecutorService createAsyncQueryExecutorService( - EMRServerlessClient emrServerlessClient) { - AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); - JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - this.dataSourceService, - new DataSourceUserAuthorizationHelperImpl(client), - jobExecutionResponseReader, - new FlintIndexMetadataReaderImpl(client), - client, - new SessionManager( - new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); - return new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, - sparkQueryDispatcher, - this::sparkExecutionEngineConfig); - } - - public static class LocalEMRSClient implements EMRServerlessClient { - - private int startJobRunCalled = 0; - private int cancelJobRunCalled = 0; - private int getJobResult = 0; - - @Getter private StartJobRequest jobRequest; - - @Override - public String startJobRun(StartJobRequest startJobRequest) { - jobRequest = startJobRequest; - startJobRunCalled++; - return "jobId"; - } - - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - getJobResult++; - JobRun jobRun = new JobRun(); - jobRun.setState("RUNNING"); - return new GetJobRunResult().withJobRun(jobRun); - } - - @Override - public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { - cancelJobRunCalled++; - return new CancelJobRunResult().withJobRunId(jobId); - } - - public void startJobRunCalled(int expectedTimes) { - assertEquals(expectedTimes, startJobRunCalled); - } - - public void cancelJobRunCalled(int expectedTimes) { - assertEquals(expectedTimes, cancelJobRunCalled); - } - - public void getJobRunResultCalled(int expectedTimes) { - assertEquals(expectedTimes, getJobResult); - } - } - - public SparkExecutionEngineConfig sparkExecutionEngineConfig() { - return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); - } - - public void enableSession(boolean enabled) { - client - .admin() - .cluster() - .prepareUpdateSettings() - .setTransientSettings( - Settings.builder() - .put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled) - .build()) - .get(); - } - - int search(QueryBuilder query) { - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(DATASOURCE)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(query); - searchRequest.source(searchSourceBuilder); - SearchResponse searchResponse = client.search(searchRequest).actionGet(); - - return searchResponse.getHits().getHits().length; - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 15211dec01..58fe626dae 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -726,7 +726,7 @@ void testGetQueryResponseWithSession() { doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); JSONObject result = sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); @@ -740,7 +740,7 @@ void testGetQueryResponseWithInvalidSession() { doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); IllegalArgumentException exception = Assertions.assertThrows( IllegalArgumentException.class, @@ -759,7 +759,7 @@ void testGetQueryResponseWithStatementNotExist() { doReturn(Optional.empty()).when(session).get(any()); doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); IllegalArgumentException exception = Assertions.assertThrows( diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 06a8d8c73c..429c970365 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -8,12 +8,10 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; -import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -22,17 +20,15 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.OpenSearchSingleNodeTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. */ -public class InteractiveSessionTest extends OpenSearchIntegTestCase { +public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { - private static final String DS_NAME = "mys3"; - private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); + private static final String indexName = "mockindex"; private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; @@ -42,21 +38,20 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase { public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(client(), clusterService()); + stateStore = new StateStore(indexName, client()); + createIndex(indexName); } @After public void clean() { - if (clusterService().state().routingTable().hasIndex(indexName)) { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); - } + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); } @Test public void openCloseSession() { InteractiveSession session = InteractiveSession.builder() - .sessionId(SessionId.newSessionId(DS_NAME)) + .sessionId(SessionId.newSessionId()) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); @@ -64,7 +59,7 @@ public void openCloseSession() { // open session TestSession testSession = testSession(session, stateStore); testSession - .open(createSessionRequest()) + .open(new CreateSessionRequest(startJobRequest, "datasource")) .assertSessionState(NOT_STARTED) .assertAppId("appId") .assertJobId("jobId"); @@ -77,14 +72,14 @@ public void openCloseSession() { @Test public void openSessionFailedConflict() { - SessionId sessionId = SessionId.newSessionId(DS_NAME); + SessionId sessionId = new SessionId("duplicate-session-id"); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(createSessionRequest()); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); InteractiveSession duplicateSession = InteractiveSession.builder() @@ -94,20 +89,21 @@ public void openSessionFailedConflict() { .build(); IllegalStateException exception = assertThrows( - IllegalStateException.class, () -> duplicateSession.open(createSessionRequest())); - assertEquals("session already exist. " + sessionId, exception.getMessage()); + IllegalStateException.class, + () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); + assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage()); } @Test public void closeNotExistSession() { - SessionId sessionId = SessionId.newSessionId(DS_NAME); + SessionId sessionId = SessionId.newSessionId(); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(createSessionRequest()); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); @@ -120,7 +116,7 @@ public void closeNotExistSession() { public void sessionManagerCreateSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); TestSession testSession = testSession(session, stateStore); testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); @@ -130,7 +126,8 @@ public void sessionManagerCreateSession() { public void sessionManagerGetSession() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); Optional managerSession = sessionManager.getSession(session.getSessionId()); assertTrue(managerSession.isPresent()); @@ -142,8 +139,7 @@ public void sessionManagerGetSessionNotExist() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Optional managerSession = - sessionManager.getSession(SessionId.newSessionId("no-exist")); + Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); assertTrue(managerSession.isEmpty()); } @@ -160,7 +156,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId()); + getSession(stateStore).apply(session.getSessionModel().getId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -188,17 +184,6 @@ public TestSession close() { } } - public static CreateSessionRequest createSessionRequest() { - return new CreateSessionRequest( - "jobName", - "appId", - "arn", - SparkSubmitParameters.Builder.builder(), - ImmutableMap.of(), - "resultIndex", - DS_NAME); - } - public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index ff3ddd1bef..214bcb8258 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,16 +5,15 @@ package org.opensearch.sql.spark.execution.statement; -import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; +import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.junit.After; @@ -22,6 +21,8 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -29,27 +30,27 @@ import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.rest.model.LangType; -import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.OpenSearchSingleNodeTestCase; -public class StatementTest extends OpenSearchIntegTestCase { +public class StatementTest extends OpenSearchSingleNodeTestCase { - private static final String DS_NAME = "mys3"; - private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); + private static final String indexName = "mockindex"; + private StartJobRequest startJobRequest; private StateStore stateStore; private InteractiveSessionTest.TestEMRServerlessClient emrsClient = new InteractiveSessionTest.TestEMRServerlessClient(); @Before public void setup() { - stateStore = new StateStore(client(), clusterService()); + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new StateStore(indexName, client()); + createIndex(indexName); } @After public void clean() { - if (clusterService().state().routingTable().hasIndex(indexName)) { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); - } + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); } @Test @@ -61,7 +62,6 @@ public void openThenCancelStatement() { .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) - .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -87,7 +87,6 @@ public void openFailedBecauseConflict() { .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) - .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -102,7 +101,6 @@ public void openFailedBecauseConflict() { .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) - .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -121,14 +119,13 @@ public void cancelNotExistStatement() { .jobId("jobId") .statementId(stId) .langType(LangType.SQL) - .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) .build(); st.open(); - client().delete(new DeleteRequest(indexName, stId.getId())).actionGet(); + client().delete(new DeleteRequest(indexName, stId.getId())); IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); assertEquals( @@ -146,7 +143,6 @@ public void cancelFailedBecauseOfConflict() { .jobId("jobId") .statementId(stId) .langType(LangType.SQL) - .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -154,7 +150,7 @@ public void cancelFailedBecauseOfConflict() { st.open(); StatementModel running = - updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED); + updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED); assertEquals(StatementState.CANCELLED, running.getStatementState()); @@ -176,7 +172,6 @@ public void cancelRunningStatementFailed() { .jobId("jobId") .statementId(stId) .langType(LangType.SQL) - .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -203,10 +198,10 @@ public void cancelRunningStatementFailed() { public void submitStatementInRunningSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); assertFalse(statementId.getId().isEmpty()); @@ -216,7 +211,7 @@ public void submitStatementInRunningSession() { public void submitStatementInNotStartedState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); assertFalse(statementId.getId().isEmpty()); @@ -226,9 +221,9 @@ public void submitStatementInNotStartedState() { public void failToSubmitStatementInDeadState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); IllegalStateException exception = assertThrows( @@ -244,9 +239,9 @@ public void failToSubmitStatementInDeadState() { public void failToSubmitStatementInFailState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); IllegalStateException exception = assertThrows( @@ -262,7 +257,7 @@ public void failToSubmitStatementInFailState() { public void newStatementFieldAssert() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -280,7 +275,7 @@ public void newStatementFieldAssert() { public void failToSubmitStatementInDeletedSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // other's delete session client() @@ -298,9 +293,9 @@ public void failToSubmitStatementInDeletedSession() { public void getStatementSuccess() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -313,9 +308,9 @@ public void getStatementSuccess() { public void getStatementNotExist() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(createSessionRequest()); + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); Optional statement = session.get(StatementId.newStatementId()); assertFalse(statement.isPresent()); @@ -333,8 +328,7 @@ public static TestStatement testStatement(Statement st, StateStore stateStore) { public TestStatement assertSessionState(StatementState expected) { assertEquals(expected, st.getStatementModel().getStatementState()); - Optional model = - getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementState()); @@ -344,8 +338,7 @@ public TestStatement assertSessionState(StatementState expected) { public TestStatement assertStatementId(StatementId expected) { assertEquals(expected, st.getStatementModel().getStatementId()); - Optional model = - getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementId()); return this;