diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 8daf0e9bf6..89d046b3d9 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -5,6 +5,8 @@ package org.opensearch.sql.common.setting; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED; + import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -36,7 +38,8 @@ public enum Key { METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"), METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"), SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"), - CLUSTER_NAME("cluster.name"); + CLUSTER_NAME("cluster.name"), + SPARK_EXECUTION_SESSION_ENABLED("plugins.query.executionengine.spark.session.enabled"); @Getter private final String keyValue; @@ -60,4 +63,9 @@ public static Optional of(String keyValue) { public abstract T getSettingValue(Key key); public abstract List getSettings(); + + /** Helper class */ + public static boolean isSparkExecutionSessionEnabled(Settings settings) { + return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); + } } diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst index b5da4e28e2..cd56e76491 100644 --- a/docs/user/admin/settings.rst +++ b/docs/user/admin/settings.rst @@ -311,3 +311,39 @@ SQL query:: "status": 400 } +plugins.query.executionengine.spark.session.enabled +=================================================== + +Description +----------- + +By default, execution engine is executed in job mode. You can enable session mode by this setting. + +1. The default value is false. +2. This setting is node scope. +3. This setting can be updated dynamically. + +You can update the setting with a new value like this. + +SQL query:: + + sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \ + ... -d '{"transient":{"plugins.query.executionengine.spark.session.enabled":"true"}}' + { + "acknowledged": true, + "persistent": {}, + "transient": { + "plugins": { + "query": { + "executionengine": { + "spark": { + "session": { + "enabled": "true" + } + } + } + } + } + } + } + diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index a9fc77264c..3fbc16d15f 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -62,6 +62,50 @@ Sample Response:: "queryId": "00fd796ut1a7eg0q" } +Execute query in session +------------------------ + +if plugins.query.executionengine.spark.session.enabled is set to true, session based execution is enabled. Under the hood, all queries submitted to the same session will be executed in the same SparkContext. Session is auto closed if not query submission in 10 minutes. + +Async query response include ``sessionId`` indicate the query is executed in session. + +Sample Request:: + + curl --location 'http://localhost:9200/_plugins/_async_query' \ + --header 'Content-Type: application/json' \ + --data '{ + "datasource" : "my_glue", + "lang" : "sql", + "query" : "select * from my_glue.default.http_logs limit 10" + }' + +Sample Response:: + + { + "queryId": "HlbM61kX6MDkAktO", + "sessionId": "1Giy65ZnzNlmsPAm" + } + +User could reuse the session by using ``sessionId`` query parameters. + +Sample Request:: + + curl --location 'http://localhost:9200/_plugins/_async_query' \ + --header 'Content-Type: application/json' \ + --data '{ + "datasource" : "my_glue", + "lang" : "sql", + "query" : "select * from my_glue.default.http_logs limit 10", + "sessionId" : "1Giy65ZnzNlmsPAm" + }' + +Sample Response:: + + { + "queryId": "7GC4mHhftiTejvxN", + "sessionId": "1Giy65ZnzNlmsPAm" + } + Async Query Result API ====================================== diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 76bda07607..ecb35afafa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -135,6 +135,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting SPARK_EXECUTION_SESSION_ENABLED_SETTING = + Setting.boolSetting( + Key.SPARK_EXECUTION_SESSION_ENABLED.getKeyValue(), + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -205,6 +212,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.SPARK_EXECUTION_ENGINE_CONFIG, SPARK_EXECUTION_ENGINE_CONFIG, new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG)); + register( + settingBuilder, + clusterSettings, + Key.SPARK_EXECUTION_SESSION_ENABLED, + SPARK_EXECUTION_SESSION_ENABLED_SETTING, + new Updater(Key.SPARK_EXECUTION_SESSION_ENABLED)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -270,6 +283,7 @@ public static List> pluginSettings() { .add(METRICS_ROLLING_INTERVAL_SETTING) .add(DATASOURCE_URI_HOSTS_DENY_LIST) .add(SPARK_EXECUTION_ENGINE_CONFIG) + .add(SPARK_EXECUTION_SESSION_ENABLED_SETTING) .build(); } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f3fd043b63..a9a35f6318 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.emrserverless.AWSEMRServerless; @@ -99,6 +100,8 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; @@ -318,7 +321,11 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( new DataSourceUserAuthorizationHelperImpl(client), jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), - client); + client, + new SessionManager( + new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), + emrServerlessClient, + pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 13db103f4b..7cba2757cc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -65,14 +65,17 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), sparkExecutionEngineConfig.getClusterName(), - sparkExecutionEngineConfig.getSparkSubmitParameters())); + sparkExecutionEngineConfig.getSparkSubmitParameters(), + createAsyncQueryRequest.getSessionId())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( sparkExecutionEngineConfig.getApplicationId(), dispatchQueryResponse.getJobId(), dispatchQueryResponse.isDropIndexQuery(), - dispatchQueryResponse.getResultIndex())); - return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId()); + dispatchQueryResponse.getResultIndex(), + dispatchQueryResponse.getSessionId())); + return new CreateAsyncQueryResponse( + dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); } @Override @@ -81,6 +84,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { Optional jobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { + String sessionId = jobMetadata.get().getSessionId(); JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get()); if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = @@ -90,13 +94,18 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { result.add(sparkSqlFunctionResponseHandle.next()); } return new AsyncQueryExecutionResponse( - JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result, null); + JobRunState.SUCCESS.toString(), + sparkSqlFunctionResponseHandle.schema(), + result, + null, + sessionId); } else { return new AsyncQueryExecutionResponse( jsonObject.optString(STATUS_FIELD, JobRunState.FAILED.toString()), null, null, - jsonObject.optString(ERROR_FIELD, "")); + jsonObject.optString(ERROR_FIELD, ""), + sessionId); } } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java index d2e54af004..e5d9cffd5f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -19,4 +19,5 @@ public class AsyncQueryExecutionResponse { private final ExecutionEngine.Schema schema; private final List results; private final String error; + private final String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index b470ef989f..b80fefa173 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -30,12 +30,15 @@ public class AsyncQueryJobMetadata { private String jobId; private boolean isDropIndexQuery; private String resultIndex; + // optional sessionId. + private String sessionId; public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) { this.applicationId = applicationId; this.jobId = jobId; this.isDropIndexQuery = false; this.resultIndex = resultIndex; + this.sessionId = null; } @Override @@ -57,6 +60,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) builder.field("applicationId", metadata.getApplicationId()); builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); builder.field("resultIndex", metadata.getResultIndex()); + builder.field("sessionId", metadata.getSessionId()); builder.endObject(); return builder; } @@ -92,6 +96,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws String applicationId = null; boolean isDropIndexQuery = false; String resultIndex = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -109,6 +114,9 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "resultIndex": resultIndex = parser.textOrNull(); break; + case "sessionId": + sessionId = parser.textOrNull(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -116,6 +124,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws if (jobId == null || applicationId == null) { throw new IllegalArgumentException("jobId and applicationId are required fields."); } - return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery, resultIndex); + return new AsyncQueryJobMetadata( + applicationId, jobId, isDropIndexQuery, resultIndex, sessionId); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 284afcc0a9..1b248eb15d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -21,6 +21,7 @@ public class SparkConstants { public static final String SPARK_SQL_APPLICATION_JAR = "file:///home/hadoop/.ivy2/jars/org.opensearch_opensearch-spark-sql-application_2.12-0.1.0-SNAPSHOT.jar"; public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; + public static final String SPARK_REQUEST_BUFFER_INDEX_NAME = ".query_execution_request"; // TODO should be replaced with mvn jar. public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 347e154885..8d5ae10e91 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -16,6 +16,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import lombok.AllArgsConstructor; import lombok.Getter; @@ -39,6 +40,14 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -69,6 +78,8 @@ public class SparkQueryDispatcher { private Client client; + private SessionManager sessionManager; + public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { return handleSQLQuery(dispatchQueryRequest); @@ -111,23 +122,60 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) String error = items.optString(ERROR_FIELD, ""); result.put(ERROR_FIELD, error); } else { - // make call to EMR Serverless when related result index documents are not available - GetJobRunResult getJobRunResult = - emrServerlessClient.getJobRunResult( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - String jobState = getJobRunResult.getJobRun().getState(); - result.put(STATUS_FIELD, jobState); - result.put(ERROR_FIELD, ""); + if (asyncQueryJobMetadata.getSessionId() != null) { + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + StatementState statementState = statement.get().getStatementState(); + result.put(STATUS_FIELD, statementState.getState()); + result.put(ERROR_FIELD, ""); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. " + sessionId); + } + } else { + // make call to EMR Serverless when related result index documents are not available + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); + } } return result; } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - CancelJobRunResult cancelJobRunResult = - emrServerlessClient.cancelJobRun( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - return cancelJobRunResult.getJobRunId(); + if (asyncQueryJobMetadata.getSessionId() != null) { + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + statement.get().cancel(); + return statementId.getId(); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. " + sessionId); + } + } else { + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + return cancelJobRunResult.getJobRunId(); + } } private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) { @@ -173,7 +221,7 @@ private DispatchQueryResponse handleIndexQuery( indexDetails.getAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { @@ -198,8 +246,35 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ tags, false, dataSourceMetadata.getResultIndex()); - String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + if (sessionManager.isEnabled()) { + Session session; + if (dispatchQueryRequest.getSessionId() != null) { + // get session from request + SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); + Optional createdSession = sessionManager.getSession(sessionId); + if (createdSession.isEmpty()) { + throw new IllegalArgumentException("no session found. " + sessionId); + } + session = createdSession.get(); + } else { + // create session if not exist + session = + sessionManager.createSession( + new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); + } + StatementId statementId = + session.submit( + new QueryRequest( + dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); + return new DispatchQueryResponse( + statementId.getId(), + false, + dataSourceMetadata.getResultIndex(), + session.getSessionId().getSessionId()); + } else { + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); + } } private DispatchQueryResponse handleDropIndexQuery( @@ -229,7 +304,7 @@ private DispatchQueryResponse handleDropIndexQuery( } } return new DispatchQueryResponse( - new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex()); + new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex(), null); } private static Map getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 823a4570ce..6aa28227a1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -23,4 +23,7 @@ public class DispatchQueryRequest { /** Optional extra Spark submit parameters to include in final request */ private String extraSparkSubmitParams; + + /** Optional sessionId. */ + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 9ee5f156f2..893446c617 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -9,4 +9,5 @@ public class DispatchQueryResponse { private String jobId; private boolean isDropIndexQuery; private String resultIndex; + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index a2847cde18..861d906b9b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -13,7 +13,7 @@ public class SessionId { private final String sessionId; public static SessionId newSessionId() { - return new SessionId(RandomStringUtils.random(10, true, true)); + return new SessionId(RandomStringUtils.randomAlphanumeric(16)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 217af80caf..c34be7015f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -5,10 +5,12 @@ package org.opensearch.sql.spark.execution.session; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED; import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -21,6 +23,7 @@ public class SessionManager { private final StateStore stateStore; private final EMRServerlessClient emrServerlessClient; + private final Settings settings; public Session createSession(CreateSessionRequest request) { InteractiveSession session = @@ -47,4 +50,8 @@ public Optional getSession(SessionId sid) { } return Optional.empty(); } + + public boolean isEnabled() { + return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java index 4baff71493..d9381ad45f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -13,7 +13,7 @@ public class StatementId { private final String id; public static StatementId newStatementId() { - return new StatementId(RandomStringUtils.random(10, true, true)); + return new StatementId(RandomStringUtils.randomAlphanumeric(16)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 8802630d9f..6acf6bc9a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.rest.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_ID; import java.io.IOException; import lombok.Data; @@ -18,6 +19,8 @@ public class CreateAsyncQueryRequest { private String query; private String datasource; private LangType lang; + // optional sessionId + private String sessionId; public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.query = Validate.notNull(query, "Query can't be null"); @@ -25,11 +28,19 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.lang = Validate.notNull(lang, "lang can't be null"); } + public CreateAsyncQueryRequest(String query, String datasource, LangType lang, String sessionId) { + this.query = Validate.notNull(query, "Query can't be null"); + this.datasource = Validate.notNull(datasource, "Datasource can't be null"); + this.lang = Validate.notNull(lang, "lang can't be null"); + this.sessionId = sessionId; + } + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; LangType lang = null; String datasource = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -41,10 +52,12 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) lang = LangType.fromString(langString); } else if (fieldName.equals("datasource")) { datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - return new CreateAsyncQueryRequest(query, datasource, lang); + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java index 8cfe57c2a6..2f918308c4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java @@ -12,4 +12,6 @@ @AllArgsConstructor public class CreateAsyncQueryResponse { private String queryId; + // optional sessionId + private String sessionId; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 01bccd9030..0d4e280b61 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -78,7 +78,7 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) @@ -107,7 +107,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { "--conf spark.dynamicAllocation.enabled=false", TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index abae0377a2..3a0d8fc56d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -16,4 +16,6 @@ public class TestConstants { public static final String EMRS_JOB_NAME = "job_name"; public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER"; + public static final String MOCK_SESSION_ID = "s-0123456"; + public static final String MOCK_STATEMENT_ID = "st-0123456"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 8c0ecb2ea2..58fe626dae 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -8,8 +8,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -18,6 +22,8 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_STATEMENT_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; @@ -34,6 +40,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; @@ -58,6 +65,12 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.flint.FlintIndexType; @@ -78,6 +91,12 @@ public class SparkQueryDispatcherTest { @Mock private FlintIndexMetadata flintIndexMetadata; + @Mock private SessionManager sessionManager; + + @Mock private Session session; + + @Mock private Statement statement; + private SparkQueryDispatcher sparkQueryDispatcher; @Captor ArgumentCaptor startJobRequestArgumentCaptor; @@ -91,7 +110,8 @@ void setUp() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); } @Test @@ -256,6 +276,84 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { verifyNoInteractions(flintIndexMetadataReader); } + @Test + void testDispatchSelectQueryCreateNewSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(session).when(sessionManager).createSession(any()); + doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).getSession(any()); + Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); + } + + @Test + void testDispatchSelectQueryReuseSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, MOCK_SESSION_ID); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.of(session)) + .when(sessionManager) + .getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); + } + + @Test + void testDispatchSelectQueryInvalidSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid"); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.empty()).when(sessionManager).getSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals( + "no session found. " + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testDispatchSelectQueryFailedCreateSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doThrow(RuntimeException.class).when(sessionManager).createSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + Assertions.assertThrows( + RuntimeException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + } + @Test void testDispatchIndexQuery() { HashMap tags = new HashMap<>(); @@ -544,6 +642,68 @@ void testCancelJob() { Assertions.assertEquals(EMR_JOB_ID, jobId); } + @Test + void testCancelQueryWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doNothing().when(statement).cancel(); + + String queryId = + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + verify(statement, times(1)).cancel(); + Assertions.assertEquals(MOCK_STATEMENT_ID, queryId); + } + + @Test + void testCancelQueryWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(new SessionId("invalid")); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(session); + Assertions.assertEquals( + "no session found. " + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithInvalidStatementId() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(statement); + Assertions.assertEquals( + "no statement found. " + new StatementId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithNoSessionId() { + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String jobId = + sparkQueryDispatcher.cancelJob( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + @Test void testGetQueryResponse() { when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) @@ -558,6 +718,60 @@ void testGetQueryResponse() { Assertions.assertEquals("PENDING", result.get("status")); } + @Test + void testGetQueryResponseWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doReturn(StatementState.WAITING).when(statement).getStatementState(); + + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + JSONObject result = + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals("waiting", result.get("status")); + } + + @Test + void testGetQueryResponseWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no session found. " + new SessionId(MOCK_SESSION_ID), exception.getMessage()); + } + + @Test + void testGetQueryResponseWithStatementNotExist() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.empty()).when(session).get(any()); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); + } + @Test void testGetQueryResponseWithSuccess() { SparkQueryDispatcher sparkQueryDispatcher = @@ -567,7 +781,8 @@ void testGetQueryResponseWithSuccess() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); JSONObject queryResult = new JSONObject(); Map resultMap = new HashMap<>(); resultMap.put(STATUS_FIELD, "SUCCESS"); @@ -604,14 +819,15 @@ void testGetQueryResponseOfDropIndex() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); String jobId = new SparkQueryDispatcher.DropIndexResult(JobRunState.SUCCESS.toString()).toJobId(); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null)); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null, null)); verify(jobExecutionResponseReader, times(0)) .getResultFromOpensearchIndex(anyString(), anyString()); Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); @@ -978,6 +1194,24 @@ private DispatchQueryRequest constructDispatchQueryRequest( langType, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME, - extraParameters); + extraParameters, + null); + } + + private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, String sessionId) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + null, + sessionId); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( + String queryId, String sessionId) { + return new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, queryId, false, null, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 488252d05a..429c970365 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; @@ -114,7 +115,7 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); TestSession testSession = testSession(session, stateStore); @@ -123,7 +124,8 @@ public void sessionManagerCreateSession() { @Test public void sessionManagerGetSession() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + SessionManager sessionManager = + new SessionManager(stateStore, emrsClient, sessionSetting(false)); Session session = sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); @@ -134,7 +136,8 @@ public void sessionManagerGetSession() { @Test public void sessionManagerGetSessionNotExist() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + SessionManager sessionManager = + new SessionManager(stateStore, emrsClient, sessionSetting(false)); Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); assertTrue(managerSession.isEmpty()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index 95b85613be..4374bd4f11 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -5,25 +5,48 @@ package org.opensearch.sql.spark.execution.session; -import org.junit.After; -import org.junit.Before; -import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchSingleNodeTestCase; -class SessionManagerTest extends OpenSearchSingleNodeTestCase { - private static final String indexName = "mockindex"; +@ExtendWith(MockitoExtension.class) +public class SessionManagerTest { + @Mock private StateStore stateStore; + @Mock private EMRServerlessClient emrClient; - private StateStore stateStore; + @Test + public void sessionEnable() { + Assertions.assertTrue( + new SessionManager(stateStore, emrClient, sessionSetting(true)).isEnabled()); + Assertions.assertFalse( + new SessionManager(stateStore, emrClient, sessionSetting(false)).isEnabled()); + } - @Before - public void setup() { - stateStore = new StateStore(indexName, client()); - createIndex(indexName); + public static Settings sessionSetting(boolean enabled) { + Map settings = new HashMap<>(); + settings.put(Settings.Key.SPARK_EXECUTION_SESSION_ENABLED, enabled); + return settings(settings); } - @After - public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + public static Settings settings(Map settings) { + return new Settings() { + @Override + public T getSettingValue(Key key) { + return (T) settings.get(key); + } + + @Override + public List getSettings() { + return (List) settings; + } + }; } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 331955e14e..214bcb8258 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; @@ -196,7 +197,7 @@ public void cancelRunningStatementFailed() { @Test public void submitStatementInRunningSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running @@ -209,7 +210,7 @@ public void submitStatementInRunningSession() { @Test public void submitStatementInNotStartedState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); @@ -219,7 +220,7 @@ public void submitStatementInNotStartedState() { @Test public void failToSubmitStatementInDeadState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); @@ -237,7 +238,7 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); @@ -255,7 +256,7 @@ public void failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -273,7 +274,7 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // other's delete session @@ -291,7 +292,7 @@ public void failToSubmitStatementInDeletedSession() { @Test public void getStatementSuccess() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); @@ -306,7 +307,7 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java new file mode 100644 index 0000000000..dd634d6055 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import java.io.IOException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +public class CreateAsyncQueryRequestTest { + + @Test + public void fromXContent() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("my_glue", queryRequest.getDatasource()); + Assertions.assertEquals(LangType.SQL, queryRequest.getLang()); + Assertions.assertEquals("select 1", queryRequest.getQuery()); + } + + @Test + public void fromXContentWithSessionId() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\",\n" + + " \"sessionId\": \"00fdjevgkf12s00q\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("00fdjevgkf12s00q", queryRequest.getSessionId()); + } + + private XContentParser xContentParser(String request) throws IOException { + return XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, request); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 8599e4b88e..36060d3850 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import java.util.HashSet; import org.junit.jupiter.api.Assertions; @@ -61,7 +62,7 @@ public void testDoExecute() { CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) - .thenReturn(new CreateAsyncQueryResponse("123")); + .thenReturn(new CreateAsyncQueryResponse("123", null)); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); CreateAsyncQueryActionResponse createAsyncQueryActionResponse = @@ -70,6 +71,24 @@ public void testDoExecute() { "{\n" + " \"queryId\": \"123\"\n" + "}", createAsyncQueryActionResponse.getResult()); } + @Test + public void testDoExecuteWithSessionId() { + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest( + "source = my_glue.default.alb_logs", "my_glue", LangType.SQL, MOCK_SESSION_ID); + CreateAsyncQueryActionRequest request = + new CreateAsyncQueryActionRequest(createAsyncQueryRequest); + when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) + .thenReturn(new CreateAsyncQueryResponse("123", MOCK_SESSION_ID)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + CreateAsyncQueryActionResponse createAsyncQueryActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + " \"queryId\": \"123\",\n" + " \"sessionId\": \"s-0123456\"\n" + "}", + createAsyncQueryActionResponse.getResult()); + } + @Test public void testDoExecuteWithException() { CreateAsyncQueryRequest createAsyncQueryRequest = diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 21a213c7c2..34f10b0083 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -63,7 +63,7 @@ public void setUp() { public void testDoExecute() { GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null); + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); @@ -89,6 +89,7 @@ public void testDoExecuteWithSuccessResponse() { Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Smith", "age", 30))), + null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener);