From b86cf2f9373307a57063edc2d96eafb4e37bbc8b Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 12 Oct 2023 09:12:51 -0700 Subject: [PATCH 01/16] release-notes-2.11 (#2284) * add 2.11 release notes --------- Signed-off-by: YANGDB --- .../opensearch-sql.release-notes-2.11.0.0.md | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 release-notes/opensearch-sql.release-notes-2.11.0.0.md diff --git a/release-notes/opensearch-sql.release-notes-2.11.0.0.md b/release-notes/opensearch-sql.release-notes-2.11.0.0.md new file mode 100644 index 0000000000..a560d5c8dd --- /dev/null +++ b/release-notes/opensearch-sql.release-notes-2.11.0.0.md @@ -0,0 +1,55 @@ +Compatible with OpenSearch and OpenSearch Dashboards Version 2.11.0 + +### Features + +### Enhancements +* Enable PPL lang and add datasource to async query API in https://github.com/opensearch-project/sql/pull/2195 +* Refactor Flint Auth in https://github.com/opensearch-project/sql/pull/2201 +* Add conf for spark structured streaming job in https://github.com/opensearch-project/sql/pull/2203 +* Submit long running job only when auto_refresh = false in https://github.com/opensearch-project/sql/pull/2209 +* Bug Fix, handle DESC TABLE response in https://github.com/opensearch-project/sql/pull/2213 +* Drop Index Implementation in https://github.com/opensearch-project/sql/pull/2217 +* Enable PPL Queries in https://github.com/opensearch-project/sql/pull/2223 +* Read extra Spark submit parameters from cluster settings in https://github.com/opensearch-project/sql/pull/2236 +* Spark Execution Engine Config Refactor in https://github.com/opensearch-project/sql/pull/2266 +* Provide auth.type and auth.role_arn paramters in GET Datasource API response. in https://github.com/opensearch-project/sql/pull/2283 +* Add support for `date_nanos` and tests. (#337) in https://github.com/opensearch-project/sql/pull/2020 +* Applied formatting improvements to Antlr files based on spotless changes (#2017) by @MitchellGale in https://github.com/opensearch-project/sql/pull/2023 +* Revert "Guarantee datasource read api is strong consistent read (#1815)" in https://github.com/opensearch-project/sql/pull/2031 +* Add _primary preference only for segment replication enabled indices in https://github.com/opensearch-project/sql/pull/2045 +* Changed allowlist config to denylist ip config for datasource uri hosts in https://github.com/opensearch-project/sql/pull/2058 + +### Bug Fixes +* fix broken link for connectors doc in https://github.com/opensearch-project/sql/pull/2199 +* Fix response codes returned by JSON formatting them in https://github.com/opensearch-project/sql/pull/2200 +* Bug fix, datasource API should be case sensitive in https://github.com/opensearch-project/sql/pull/2202 +* Minor fix in dropping covering index in https://github.com/opensearch-project/sql/pull/2240 +* Fix Unit tests for FlintIndexReader in https://github.com/opensearch-project/sql/pull/2242 +* Bug Fix , delete OpenSearch index when DROP INDEX in https://github.com/opensearch-project/sql/pull/2252 +* Correctly Set query status in https://github.com/opensearch-project/sql/pull/2232 +* Exclude generated files from spotless in https://github.com/opensearch-project/sql/pull/2024 +* Fix mockito core conflict. in https://github.com/opensearch-project/sql/pull/2131 +* Fix `ASCII` function and groom UT for text functions. (#301) in https://github.com/opensearch-project/sql/pull/2029 +* Fixed response codes For Requests With security exception. 
in https://github.com/opensearch-project/sql/pull/2036 + +### Documentation +* Datasource description in https://github.com/opensearch-project/sql/pull/2138 +* Add documentation for S3GlueConnector. in https://github.com/opensearch-project/sql/pull/2234 + +### Infrastructure +* bump aws-encryption-sdk-java to 1.71 in https://github.com/opensearch-project/sql/pull/2057 +* Run IT tests with security plugin (#335) #1986 by @MitchellGale in https://github.com/opensearch-project/sql/pull/2022 + +### Refactoring +* Merging Async Query APIs feature branch into main. in https://github.com/opensearch-project/sql/pull/2163 +* Removed Domain Validation in https://github.com/opensearch-project/sql/pull/2136 +* Check for existence of security plugin in https://github.com/opensearch-project/sql/pull/2069 +* Always use snapshot version for security plugin download in https://github.com/opensearch-project/sql/pull/2061 +* Add customized result index in data source etc in https://github.com/opensearch-project/sql/pull/2220 + +### Security +* bump okhttp to 4.10.0 (#2043) by @joshuali925 in https://github.com/opensearch-project/sql/pull/2044 +* bump okio to 3.4.0 by @joshuali925 in https://github.com/opensearch-project/sql/pull/2047 + +--- +**Full Changelog**: https://github.com/opensearch-project/sql/compare/2.3.0.0...v.2.11.0.0 \ No newline at end of file From f856cb3f15eb079e554fd1301a4bcfd7d5fefc0d Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 13 Oct 2023 14:34:12 -0700 Subject: [PATCH 02/16] add InteractiveSession and SessionManager (#2290) * add InteractiveSession and SessionManager Signed-off-by: Peng Huo * address comments Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- spark/build.gradle | 39 +++- .../session/CreateSessionRequest.java | 15 ++ .../execution/session/InteractiveSession.java | 61 +++++ .../sql/spark/execution/session/Session.java | 19 ++ .../spark/execution/session/SessionId.java | 23 ++ .../execution/session/SessionManager.java | 50 ++++ .../spark/execution/session/SessionModel.java | 143 ++++++++++++ .../spark/execution/session/SessionState.java | 36 +++ .../spark/execution/session/SessionType.java | 33 +++ .../statestore/SessionStateStore.java | 87 +++++++ .../session/InteractiveSessionTest.java | 213 ++++++++++++++++++ .../execution/session/SessionManagerTest.java | 38 ++++ .../execution/session/SessionStateTest.java | 20 ++ .../execution/session/SessionTypeTest.java | 20 ++ .../statestore/SessionStateStoreTest.java | 42 ++++ 15 files changed, 834 insertions(+), 5 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java create mode 100644 
spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java diff --git a/spark/build.gradle b/spark/build.gradle index c06b5b6ecf..c2c925ecaf 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -52,15 +52,38 @@ dependencies { api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' - testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') + testImplementation(platform("org.junit:junit-bom:5.6.2")) + + testImplementation('org.junit.jupiter:junit-jupiter') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0' - testImplementation 'junit:junit:4.13.1' - testImplementation "org.opensearch.test:framework:${opensearch_version}" + + testCompileOnly('junit:junit:4.13.1') { + exclude group: 'org.hamcrest', module: 'hamcrest-core' + } + testRuntimeOnly("org.junit.vintage:junit-vintage-engine") { + exclude group: 'org.hamcrest', module: 'hamcrest-core' + } + testRuntimeOnly("org.junit.platform:junit-platform-launcher") { + because 'allows tests to run from IDEs that bundle older version of launcher' + } + testImplementation("org.opensearch.test:framework:${opensearch_version}") } test { - useJUnitPlatform() + useJUnitPlatform { + includeEngines("junit-jupiter") + } + testLogging { + events "failed" + exceptionFormat "full" + } +} +task junit4(type: Test) { + useJUnitPlatform { + includeEngines("junit-vintage") + } + systemProperty 'tests.security.manager', 'false' testLogging { events "failed" exceptionFormat "full" @@ -68,6 +91,8 @@ test { } jacocoTestReport { + dependsOn test, junit4 + executionData test, junit4 reports { html.enabled true xml.enabled true @@ -78,9 +103,10 @@ jacocoTestReport { })) } } -test.finalizedBy(project.tasks.jacocoTestReport) jacocoTestCoverageVerification { + dependsOn test, junit4 + executionData test, junit4 violationRules { rule { element = 'CLASS' @@ -92,6 +118,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.asyncquery.exceptions.*', 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', + // ignore because XContext IOException + 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java new file mode 100644 index 0000000000..17e3346248 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import lombok.Data; +import org.opensearch.sql.spark.client.StartJobRequest; + +@Data +public class CreateSessionRequest { + private final StartJobRequest 
startJobRequest; + private final String datasourceName; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java new file mode 100644 index 0000000000..620e46b9be --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; + +import java.util.Optional; +import lombok.Builder; +import lombok.Getter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; + +/** + * Interactive session. + * + *
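+ * <p>An interactive session is backed by a single long-running Spark job: {@code open} starts
+ * the job run and persists a {@link SessionModel}; {@code close} cancels the job run.
+ *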
ENTRY_STATE: not_started + */ +@Getter +@Builder +public class InteractiveSession implements Session { + private static final Logger LOG = LogManager.getLogger(); + + private final SessionId sessionId; + private final SessionStateStore sessionStateStore; + private final EMRServerlessClient serverlessClient; + + private SessionModel sessionModel; + + @Override + public void open(CreateSessionRequest createSessionRequest) { + try { + String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); + String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); + + sessionModel = + initInteractiveSession( + applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); + sessionStateStore.create(sessionModel); + } catch (VersionConflictEngineException e) { + String errorMsg = "session already exist. " + sessionId; + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + @Override + public void close() { + Optional model = sessionStateStore.get(sessionModel.getSessionId()); + if (model.isEmpty()) { + throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); + } else { + serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java new file mode 100644 index 0000000000..ec9775e60a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +/** Session define the statement execution context. Each session is binding to one Spark Job. */ +public interface Session { + /** open session. */ + void open(CreateSessionRequest createSessionRequest); + + /** close session. 
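+   * Closing cancels the session's underlying Spark job run.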
*/ + void close(); + + SessionModel getSessionModel(); + + SessionId getSessionId(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java new file mode 100644 index 0000000000..a2847cde18 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; + +@Data +public class SessionId { + private final String sessionId; + + public static SessionId newSessionId() { + return new SessionId(RandomStringUtils.random(10, true, true)); + } + + @Override + public String toString() { + return "sessionId=" + sessionId; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java new file mode 100644 index 0000000000..3d0916bac8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; + +/** + * Singleton Class + * + *
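+ * <p>Typical flow: {@code createSession} opens a new {@link InteractiveSession} and persists
+ * it; {@code getSession} rebuilds a session from the state store by id.
+ *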
todo. add Session cache and Session sweeper. + */ +@RequiredArgsConstructor +public class SessionManager { + private final SessionStateStore stateStore; + private final EMRServerlessClient emrServerlessClient; + + public Session createSession(CreateSessionRequest request) { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(newSessionId()) + .sessionStateStore(stateStore) + .serverlessClient(emrServerlessClient) + .build(); + session.open(request); + return session; + } + + public Optional getSession(SessionId sid) { + Optional model = stateStore.get(sid); + if (model.isPresent()) { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sid) + .sessionStateStore(stateStore) + .serverlessClient(emrServerlessClient) + .sessionModel(model.get()) + .build(); + return Optional.ofNullable(session); + } + return Optional.empty(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java new file mode 100644 index 0000000000..656f0ec8ce --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.session.SessionType.INTERACTIVE; + +import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.seqno.SequenceNumbers; + +/** Session data in flint.ql.sessions index. 
*/ +@Data +@Builder +public class SessionModel implements ToXContentObject { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String SESSION_TYPE = "sessionType"; + public static final String SESSION_ID = "sessionId"; + public static final String SESSION_STATE = "state"; + public static final String DATASOURCE_NAME = "dataSourceName"; + public static final String LAST_UPDATE_TIME = "lastUpdateTime"; + public static final String APPLICATION_ID = "applicationId"; + public static final String JOB_ID = "jobId"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String SESSION_DOC_TYPE = "session"; + + private final String version; + private final SessionType sessionType; + private final SessionId sessionId; + private final SessionState sessionState; + private final String applicationId; + private final String jobId; + private final String datasourceName; + private final String error; + private final long lastUpdateTime; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, SESSION_DOC_TYPE) + .field(SESSION_TYPE, sessionType.getSessionType()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(SESSION_STATE, sessionState.getSessionState()) + .field(DATASOURCE_NAME, datasourceName) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LAST_UPDATE_TIME, lastUpdateTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(copy.sessionState) + .datasourceName(copy.datasourceName) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static SessionModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + SessionModelBuilder builder = new SessionModelBuilder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case SESSION_TYPE: + builder.sessionType(SessionType.fromString(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case SESSION_STATE: + builder.sessionState(SessionState.fromString(parser.text())); + break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; + case ERROR: + builder.error(parser.text()); + break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; + case LAST_UPDATE_TIME: + builder.lastUpdateTime(parser.longValue()); + break; + case TYPE: + // do nothing. 
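+          // TYPE is the constant doc marker (SESSION_DOC_TYPE); there is no builder field to set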
+ break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static SessionModel initInteractiveSession( + String applicationId, String jobId, SessionId sid, String datasourceName) { + return builder() + .version("1.0") + .sessionType(INTERACTIVE) + .sessionId(sid) + .sessionState(NOT_STARTED) + .datasourceName(datasourceName) + .applicationId(applicationId) + .jobId(jobId) + .error(UNKNOWN) + .lastUpdateTime(System.currentTimeMillis()) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java new file mode 100644 index 0000000000..509d5105e9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum SessionState { + NOT_STARTED("not_started"), + RUNNING("running"), + DEAD("dead"), + FAIL("fail"); + + private final String sessionState; + + SessionState(String sessionState) { + this.sessionState = sessionState; + } + + private static Map STATES = + Arrays.stream(SessionState.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static SessionState fromString(String key) { + if (STATES.containsKey(key)) { + return STATES.get(key); + } + throw new IllegalArgumentException("Invalid session state: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java new file mode 100644 index 0000000000..dd179a1dc5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum SessionType { + INTERACTIVE("interactive"); + + private final String sessionType; + + SessionType(String sessionType) { + this.sessionType = sessionType; + } + + private static Map TYPES = + Arrays.stream(SessionType.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static SessionType fromString(String key) { + if (TYPES.containsKey(key)) { + return TYPES.get(key); + } + throw new IllegalArgumentException("Invalid session type: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java new file mode 100644 index 0000000000..6ddce55360 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import 
org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@RequiredArgsConstructor +public class SessionStateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + public SessionModel create(SessionModel session) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(session.getSessionId().getSessionId()) + .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(session.getSeqNo()) + .setIfPrimaryTerm(session.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", session.getSessionId()); + return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. 
id: %s, error: %s", + session.getSessionId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public Optional get(SessionId sid) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + SessionModel.fromXContent( + parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java new file mode 100644 index 0000000000..53dc211ded --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +/** mock-maker-inline does not work with OpenSearchTestCase. 
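+ * Tests run against a real single-node cluster ({@link OpenSearchSingleNodeTestCase}) with a
+ * stubbed EMR serverless client instead of inline mocks.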
*/ +public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { + + private static final String indexName = "mockindex"; + + private TestEMRServerlessClient emrsClient; + private StartJobRequest startJobRequest; + private SessionStateStore stateStore; + + @Before + public void setup() { + emrsClient = new TestEMRServerlessClient(); + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new SessionStateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + + @Test + public void openCloseSession() { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(SessionId.newSessionId()) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + + // open session + TestSession testSession = testSession(session, stateStore); + testSession + .open(new CreateSessionRequest(startJobRequest, "datasource")) + .assertSessionState(NOT_STARTED) + .assertAppId("appId") + .assertJobId("jobId"); + emrsClient.startJobRunCalled(1); + + // close session + testSession.close(); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void openSessionFailedConflict() { + SessionId sessionId = new SessionId("duplicate-session-id"); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); + + InteractiveSession duplicateSession = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); + assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage()); + } + + @Test + public void closeNotExistSession() { + SessionId sessionId = SessionId.newSessionId(); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); + + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); + + IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); + assertEquals("session not exist. 
" + sessionId, exception.getMessage()); + emrsClient.cancelJobRunCalled(0); + } + + @Test + public void sessionManagerCreateSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + TestSession testSession = testSession(session, stateStore); + testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + } + + @Test + public void sessionManagerGetSession() { + SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + Session session = + sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + Optional managerSession = sessionManager.getSession(session.getSessionId()); + assertTrue(managerSession.isPresent()); + assertEquals(session.getSessionId(), managerSession.get().getSessionId()); + } + + @Test + public void sessionManagerGetSessionNotExist() { + SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + + Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); + assertTrue(managerSession.isEmpty()); + } + + @RequiredArgsConstructor + static class TestSession { + private final Session session; + private final SessionStateStore stateStore; + + public static TestSession testSession(Session session, SessionStateStore stateStore) { + return new TestSession(session, stateStore); + } + + public TestSession assertSessionState(SessionState expected) { + assertEquals(expected, session.getSessionModel().getSessionState()); + + Optional sessionStoreState = + stateStore.get(session.getSessionModel().getSessionId()); + assertTrue(sessionStoreState.isPresent()); + assertEquals(expected, sessionStoreState.get().getSessionState()); + + return this; + } + + public TestSession assertAppId(String expected) { + assertEquals(expected, session.getSessionModel().getApplicationId()); + return this; + } + + public TestSession assertJobId(String expected) { + assertEquals(expected, session.getSessionModel().getJobId()); + return this; + } + + public TestSession open(CreateSessionRequest req) { + session.open(req); + return this; + } + + public TestSession close() { + session.close(); + return this; + } + } + + static class TestEMRServerlessClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + return null; + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return null; + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java new file mode 100644 index 0000000000..d35105f787 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.jupiter.api.Assertions.*; + 
+import org.junit.After; +import org.junit.Before; +import org.mockito.MockMakers; +import org.mockito.MockSettings; +import org.mockito.Mockito; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +class SessionManagerTest extends OpenSearchSingleNodeTestCase { + private static final String indexName = "mockindex"; + + // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. + private static final MockSettings mockSettings = + Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); + + private SessionStateStore stateStore; + + @Before + public void setup() { + stateStore = new SessionStateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java new file mode 100644 index 0000000000..a987c80d59 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionStateTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionState.fromString("invalid")); + assertEquals("Invalid session state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java new file mode 100644 index 0000000000..a2ab43e709 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionTypeTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionType.fromString("invalid")); + assertEquals("Invalid session type: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java new file mode 100644 index 0000000000..9c779555d7 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static 
org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@ExtendWith(MockitoExtension.class) +class SessionStateStoreTest { + @Mock(answer = RETURNS_DEEP_STUBS) + private Client client; + + @Mock private IndexResponse indexResponse; + + @Test + public void createWithException() { + when(client.index(any()).actionGet()).thenReturn(indexResponse); + doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); + SessionModel sessionModel = + SessionModel.initInteractiveSession( + "appId", "jobId", SessionId.newSessionId(), "datasource"); + SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); + + assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); + } +} From b76a15e9db35f259f4a5a3e4567ba8c7b84bc962 Mon Sep 17 00:00:00 2001 From: Derek Ho Date: Mon, 16 Oct 2023 15:28:51 -0400 Subject: [PATCH 03/16] Bump bwc verison to 2.12 (#2292) Signed-off-by: Derek Ho --- integ-test/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 6925cb9101..f2e70d9908 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -42,7 +42,7 @@ apply plugin: 'java' apply plugin: 'io.freefair.lombok' apply plugin: 'com.wiredforcode.spawn' -String baseVersion = "2.11.0" +String baseVersion = "2.12.0" String bwcVersion = baseVersion + ".0"; String baseName = "sqlBwcCluster" String bwcFilePath = "src/test/resources/bwc/" From 501cf915d16215fcf4c5df451ea64b5e31abe3c4 Mon Sep 17 00:00:00 2001 From: Derek Ho Date: Tue, 17 Oct 2023 11:00:07 -0400 Subject: [PATCH 04/16] Upgrade json (#2307) Signed-off-by: Derek Ho --- legacy/build.gradle | 2 +- opensearch/build.gradle | 2 +- ppl/build.gradle | 2 +- prometheus/build.gradle | 2 +- spark/build.gradle | 2 +- sql/build.gradle | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/legacy/build.gradle b/legacy/build.gradle index fc985989e5..ca20476610 100644 --- a/legacy/build.gradle +++ b/legacy/build.gradle @@ -108,7 +108,7 @@ dependencies { } } implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' - implementation group: 'org.json', name: 'json', version:'20230227' + implementation group: 'org.json', name: 'json', version:'20231013' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" // add geo module as dependency. https://github.com/opensearch-project/OpenSearch/pull/4180/. 
diff --git a/opensearch/build.gradle b/opensearch/build.gradle index 34b5c3f452..c9087bca49 100644 --- a/opensearch/build.gradle +++ b/opensearch/build.gradle @@ -37,7 +37,7 @@ dependencies { implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}" implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}" implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}" - implementation group: 'org.json', name: 'json', version:'20230227' + implementation group: 'org.json', name: 'json', version:'20231013' compileOnly group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}" implementation group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" diff --git a/ppl/build.gradle b/ppl/build.gradle index a798b3f4b0..04ad71ced6 100644 --- a/ppl/build.gradle +++ b/ppl/build.gradle @@ -49,7 +49,7 @@ dependencies { implementation "org.antlr:antlr4-runtime:4.7.1" implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' - api group: 'org.json', name: 'json', version: '20230227' + api group: 'org.json', name: 'json', version: '20231013' implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0' api project(':common') api project(':core') diff --git a/prometheus/build.gradle b/prometheus/build.gradle index f8c10c7f6b..c2878ab1b4 100644 --- a/prometheus/build.gradle +++ b/prometheus/build.gradle @@ -22,7 +22,7 @@ dependencies { implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}" implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}" implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}" - implementation group: 'org.json', name: 'json', version: '20230227' + implementation group: 'org.json', name: 'json', version: '20231013' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' diff --git a/spark/build.gradle b/spark/build.gradle index c2c925ecaf..49ff96bec5 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -47,7 +47,7 @@ dependencies { implementation project(':datasources') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" - implementation group: 'org.json', name: 'json', version: '20230227' + implementation group: 'org.json', name: 'json', version: '20231013' api group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.545' api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' diff --git a/sql/build.gradle b/sql/build.gradle index 2984158e57..c9b46d38f1 100644 --- a/sql/build.gradle +++ b/sql/build.gradle @@ -47,7 +47,7 @@ dependencies { implementation "org.antlr:antlr4-runtime:4.7.1" implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' - implementation group: 'org.json', name: 'json', version:'20230227' + implementation group: 'org.json', name: 'json', version:'20231013' implementation project(':common') implementation project(':core') api project(':protocol') From 69572c8cca278a500db7710fb415cc58a2589c78 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Tue, 17 
Oct 2023 08:56:02 -0700 Subject: [PATCH 05/16] Minor Refactoring (#2308) Signed-off-by: Vamsi Manohar --- spark/src/main/antlr/SqlBaseParser.g4 | 2 +- .../sql/spark/client/StartJobRequest.java | 2 + .../dispatcher/SparkQueryDispatcherTest.java | 329 +++++++++--------- 3 files changed, 158 insertions(+), 175 deletions(-) diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 index 6a6d39e96c..77a9108e06 100644 --- a/spark/src/main/antlr/SqlBaseParser.g4 +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -967,7 +967,6 @@ primaryExpression | qualifiedName DOT ASTERISK #star | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor | LEFT_PAREN query RIGHT_PAREN #subqueryExpression - | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument (COMMA argument+=functionArgument)*)? RIGHT_PAREN (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? @@ -1196,6 +1195,7 @@ qualifiedNameList functionName : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | identFunc=IDENTIFIER_KW // IDENTIFIER itself is also a valid function name. | qualifiedName | FILTER | LEFT diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index c4382239a1..f57c8facee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -7,12 +7,14 @@ import java.util.Map; import lombok.Data; +import lombok.EqualsAndHashCode; /** * This POJO carries all the fields required for emr serverless job submission. Used as model in * {@link EMRServerlessClient} interface. 
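+ * Value-based {@code equals}/{@code hashCode} let tests compare captured requests directly.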
*/ @Data +@EqualsAndHashCode public class StartJobRequest { public static final Long DEFAULT_JOB_TIMEOUT = 120L; diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index ab9761da36..8c0ecb2ea2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -41,6 +41,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -78,6 +80,8 @@ public class SparkQueryDispatcherTest { private SparkQueryDispatcher sparkQueryDispatcher; + @Captor ArgumentCaptor startJobRequestArgumentCaptor; + @BeforeEach void setUp() { sparkQueryDispatcher = @@ -96,19 +100,21 @@ void testDispatchSelectQuery() { tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -125,23 +131,18 @@ void testDispatchSelectQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -153,20 +154,22 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "basicauth", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); + put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "basicauth", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); - put(FLINT_INDEX_STORE_AUTH_PASSWORD, 
"password"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -183,24 +186,18 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "basicauth", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); - put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -212,18 +209,20 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "noauth", + new HashMap<>() { + { + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "noauth", - new HashMap<>() { - { - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -240,22 +239,18 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "noauth", - new HashMap<>() { - { - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -272,20 +267,22 @@ void testDispatchIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, true, any()))) @@ -302,24 +299,18 @@ void 
testDispatchIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -332,19 +323,21 @@ void testDispatchWithPPLQuery() { tags.put("cluster", TEST_CLUSTER_NAME); String query = "source = my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -361,23 +354,18 @@ void testDispatchWithPPLQuery() { LangType.PPL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -390,19 +378,21 @@ void testDispatchQueryWithoutATableAndDataSourceName() { tags.put("cluster", TEST_CLUSTER_NAME); String query = "show tables"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -419,23 +409,18 @@ void testDispatchQueryWithoutATableAndDataSourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - 
"TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -453,20 +438,22 @@ void testDispatchIndexQueryWithoutADatasourceName() { String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, true, any()))) @@ -483,24 +470,18 @@ void testDispatchIndexQueryWithoutADatasourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -905,8 +886,8 @@ private String constructExpectedSparkSubmitParameterString( + " --conf" + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " - + authParamConfigBuilder - + " --conf spark.flint.datasource.name=my_glue "; + + " --conf spark.flint.datasource.name=my_glue " + + authParamConfigBuilder; } private String withStructuredStreaming(String parameters) { From 297e26f9622e66cf01c777d229e9b13dbc19525d Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 17 Oct 2023 09:20:37 -0700 Subject: [PATCH 06/16] Add Statement (#2294) * add InteractiveSession and SessionManager Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * fix format Signed-off-by: Peng Huo * address comments Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- 
spark/build.gradle | 5 +- .../execution/session/InteractiveSession.java | 73 +++- .../sql/spark/execution/session/Session.java | 21 ++ .../execution/session/SessionManager.java | 10 +- .../spark/execution/session/SessionModel.java | 30 +- .../spark/execution/session/SessionState.java | 4 + .../execution/statement/QueryRequest.java | 15 + .../spark/execution/statement/Statement.java | 85 +++++ .../execution/statement/StatementId.java | 23 ++ .../execution/statement/StatementModel.java | 194 ++++++++++ .../execution/statement/StatementState.java | 38 ++ .../statestore/SessionStateStore.java | 87 ----- .../execution/statestore/StateModel.java | 30 ++ .../execution/statestore/StateStore.java | 149 ++++++++ .../session/InteractiveSessionTest.java | 27 +- .../execution/session/SessionManagerTest.java | 15 +- .../statement/StatementStateTest.java | 20 + .../execution/statement/StatementTest.java | 356 ++++++++++++++++++ .../statestore/SessionStateStoreTest.java | 42 --- 19 files changed, 1055 insertions(+), 169 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java diff --git a/spark/build.gradle b/spark/build.gradle index 49ff96bec5..15f1e200e0 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -119,8 +119,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', // ignore because XContext IOException - 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', - 'org.opensearch.sql.spark.execution.session.SessionModel' + 'org.opensearch.sql.spark.execution.statestore.StateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel', + 'org.opensearch.sql.spark.execution.statement.StatementModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 620e46b9be..e33ef4245a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,6 +6,10 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; +import static 
org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import java.util.Optional; import lombok.Builder; @@ -14,7 +18,11 @@ import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; /** * Interactive session. @@ -27,9 +35,8 @@ public class InteractiveSession implements Session { private static final Logger LOG = LogManager.getLogger(); private final SessionId sessionId; - private final SessionStateStore sessionStateStore; + private final StateStore stateStore; private final EMRServerlessClient serverlessClient; - private SessionModel sessionModel; @Override @@ -41,7 +48,7 @@ public void open(CreateSessionRequest createSessionRequest) { sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - sessionStateStore.create(sessionModel); + createSession(stateStore).apply(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -49,13 +56,67 @@ public void open(CreateSessionRequest createSessionRequest) { } } + /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional model = sessionStateStore.get(sessionModel.getSessionId()); + Optional model = getSession(stateStore).apply(sessionModel.getId()); if (model.isEmpty()) { - throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); + throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); } } + + /** Submit statement. If submit successfully, Statement in waiting state. */ + public StatementId submit(QueryRequest request) { + Optional model = getSession(stateStore).apply(sessionModel.getId()); + if (model.isEmpty()) { + throw new IllegalStateException("session does not exist. 
" + sessionModel.getSessionId()); + } else { + sessionModel = model.get(); + if (!END_STATE.contains(sessionModel.getSessionState())) { + StatementId statementId = newStatementId(); + Statement st = + Statement.builder() + .sessionId(sessionId) + .applicationId(sessionModel.getApplicationId()) + .jobId(sessionModel.getJobId()) + .stateStore(stateStore) + .statementId(statementId) + .langType(LangType.SQL) + .query(request.getQuery()) + .queryId(statementId.getId()) + .build(); + st.open(); + return statementId; + } else { + String errMsg = + String.format( + "can't submit statement, session should not be in end state, " + + "current session state is: %s", + sessionModel.getSessionState().getSessionState()); + LOG.debug(errMsg); + throw new IllegalStateException(errMsg); + } + } + } + + @Override + public Optional get(StatementId stID) { + return StateStore.getStatement(stateStore) + .apply(stID.getId()) + .map( + model -> + Statement.builder() + .sessionId(sessionId) + .applicationId(model.getApplicationId()) + .jobId(model.getJobId()) + .statementId(model.getStatementId()) + .langType(model.getLangType()) + .query(model.getQuery()) + .queryId(model.getQueryId()) + .stateStore(stateStore) + .statementModel(model) + .build()); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index ec9775e60a..4d919d5e2e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -5,6 +5,11 @@ package org.opensearch.sql.spark.execution.session; +import java.util.Optional; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; + /** Session define the statement execution context. Each session is binding to one Spark Job. */ public interface Session { /** open session. */ @@ -13,6 +18,22 @@ public interface Session { /** close session. */ void close(); + /** + * submit {@link QueryRequest}. + * + * @param request {@link QueryRequest} + * @return {@link StatementId} + */ + StatementId submit(QueryRequest request); + + /** + * get {@link Statement}. 
+ * + * @param stID {@link StatementId} + * @return {@link Statement} + */ + Optional get(StatementId stID); + SessionModel getSessionModel(); SessionId getSessionId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 3d0916bac8..217af80caf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -10,7 +10,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; /** * Singleton Class @@ -19,14 +19,14 @@ */ @RequiredArgsConstructor public class SessionManager { - private final SessionStateStore stateStore; + private final StateStore stateStore; private final EMRServerlessClient emrServerlessClient; public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId()) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrServerlessClient) .build(); session.open(request); @@ -34,12 +34,12 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = stateStore.get(sid); + Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() .sessionId(sid) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrServerlessClient) .sessionModel(model.get()) .build(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java index 656f0ec8ce..806cdb083e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -12,16 +12,16 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.statestore.StateModel; /** Session data in flint.ql.sessions index. 
*/ @Data @Builder -public class SessionModel implements ToXContentObject { +public class SessionModel extends StateModel { public static final String VERSION = "version"; public static final String TYPE = "type"; public static final String SESSION_TYPE = "sessionType"; @@ -73,6 +73,27 @@ public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { .sessionId(new SessionId(copy.sessionId.getSessionId())) .sessionState(copy.sessionState) .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static SessionModel copyWithState( + SessionModel copy, SessionState state, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(state) + .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) .seqNo(seqNo) .primaryTerm(primaryTerm) .build(); @@ -140,4 +161,9 @@ public static SessionModel initInteractiveSession( .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) .build(); } + + @Override + public String getId() { + return sessionId.getSessionId(); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index 509d5105e9..a4da957f12 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -5,7 +5,9 @@ package org.opensearch.sql.spark.execution.session; +import com.google.common.collect.ImmutableList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -17,6 +19,8 @@ public enum SessionState { DEAD("dead"), FAIL("fail"); + public static List END_STATE = ImmutableList.of(DEAD, FAIL); + private final String sessionState; SessionState(String sessionState) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java new file mode 100644 index 0000000000..10061404ca --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import lombok.Data; +import org.opensearch.sql.spark.rest.model.LangType; + +@Data +public class QueryRequest { + private final LangType langType; + private final String query; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java new file mode 100644 index 0000000000..8fcedb5fca --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement; +import static 
org.opensearch.sql.spark.execution.statestore.StateStore.createStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Statement represent query to execute in session. One statement map to one session. */ +@Getter +@Builder +public class Statement { + private static final Logger LOG = LogManager.getLogger(); + + private final SessionId sessionId; + private final String applicationId; + private final String jobId; + private final StatementId statementId; + private final LangType langType; + private final String query; + private final String queryId; + private final StateStore stateStore; + + @Setter private StatementModel statementModel; + + /** Open a statement. */ + public void open() { + try { + statementModel = + submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); + statementModel = createStatement(stateStore).apply(statementModel); + } catch (VersionConflictEngineException e) { + String errorMsg = "statement already exist. " + statementId; + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + /** Cancel a statement. */ + public void cancel() { + if (statementModel.getStatementState().equals(StatementState.RUNNING)) { + String errorMsg = + String.format("can't cancel statement in waiting state. statement: %s.", statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + try { + this.statementModel = + updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); + } catch (DocumentMissingException e) { + String errorMsg = + String.format("cancel statement failed. no statement found. statement: %s.", statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } catch (VersionConflictEngineException e) { + this.statementModel = + getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); + String errorMsg = + String.format( + "cancel statement failed. 
current statementState: %s " + "statement: %s.", + this.statementModel.getStatementState(), statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + public StatementState getStatementState() { + return statementModel.getStatementState(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java new file mode 100644 index 0000000000..4baff71493 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; + +@Data +public class StatementId { + private final String id; + + public static StatementId newStatementId() { + return new StatementId(RandomStringUtils.random(10, true, true)); + } + + @Override + public String toString() { + return "statementId=" + id; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java new file mode 100644 index 0000000000..c7f681c541 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; +import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; + +import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.statestore.StateModel; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Statement data in flint.ql.sessions index. 
*/ +@Data +@Builder +public class StatementModel extends StateModel { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String STATEMENT_STATE = "state"; + public static final String STATEMENT_ID = "statementId"; + public static final String SESSION_ID = "sessionId"; + public static final String LANG = "lang"; + public static final String QUERY = "query"; + public static final String QUERY_ID = "queryId"; + public static final String SUBMIT_TIME = "submitTime"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String STATEMENT_DOC_TYPE = "statement"; + + private final String version; + private final StatementState statementState; + private final StatementId statementId; + private final SessionId sessionId; + private final String applicationId; + private final String jobId; + private final LangType langType; + private final String query; + private final String queryId; + private final long submitTime; + private final String error; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, STATEMENT_DOC_TYPE) + .field(STATEMENT_STATE, statementState.getState()) + .field(STATEMENT_ID, statementId.getId()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LANG, langType.getText()) + .field(QUERY, query) + .field(QUERY_ID, queryId) + .field(SUBMIT_TIME, submitTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static StatementModel copy(StatementModel copy, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(copy.statementState) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) + .langType(copy.langType) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static StatementModel copyWithState( + StatementModel copy, StatementState state, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(state) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) + .langType(copy.langType) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static StatementModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + StatementModel.StatementModelBuilder builder = StatementModel.builder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case TYPE: + // do nothing + break; + case STATEMENT_STATE: + builder.statementState(StatementState.fromString(parser.text())); + break; + case STATEMENT_ID: + builder.statementId(new StatementId(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case APPLICATION_ID: + 
builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; + case LANG: + builder.langType(LangType.fromString(parser.text())); + break; + case QUERY: + builder.query(parser.text()); + break; + case QUERY_ID: + builder.queryId(parser.text()); + break; + case SUBMIT_TIME: + builder.submitTime(parser.longValue()); + break; + case ERROR: + builder.error(parser.text()); + break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static StatementModel submitStatement( + SessionId sid, + String applicationId, + String jobId, + StatementId statementId, + LangType langType, + String query, + String queryId) { + return builder() + .version("1.0") + .statementState(WAITING) + .statementId(statementId) + .sessionId(sid) + .applicationId(applicationId) + .jobId(jobId) + .langType(langType) + .query(query) + .queryId(queryId) + .submitTime(System.currentTimeMillis()) + .error(UNKNOWN) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } + + @Override + public String getId() { + return statementId.getId(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java new file mode 100644 index 0000000000..33f7f5e831 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +/** {@link Statement} State. 
*/ +@Getter +public enum StatementState { + WAITING("waiting"), + RUNNING("running"), + SUCCESS("success"), + FAILED("failed"), + CANCELLED("cancelled"); + + private final String state; + + StatementState(String state) { + this.state = state; + } + + private static Map STATES = + Arrays.stream(StatementState.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static StatementState fromString(String key) { + if (STATES.containsKey(key)) { + return STATES.get(key); + } + throw new IllegalArgumentException("Invalid statement state: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java deleted file mode 100644 index 6ddce55360..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import java.io.IOException; -import java.util.Locale; -import java.util.Optional; -import lombok.RequiredArgsConstructor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionModel; - -@RequiredArgsConstructor -public class SessionStateStore { - private static final Logger LOG = LogManager.getLogger(); - - private final String indexName; - private final Client client; - - public SessionModel create(SessionModel session) { - try { - IndexRequest indexRequest = - new IndexRequest(indexName) - .id(session.getSessionId().getSessionId()) - .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .setIfSeqNo(session.getSeqNo()) - .setIfPrimaryTerm(session.getPrimaryTerm()) - .create(true) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", session.getSessionId()); - return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. 
id: %s, error: %s", - session.getSessionId(), - indexResponse.getResult().getLowercase())); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public Optional get(SessionId sid) { - try { - GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - SessionModel.fromXContent( - parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { - return Optional.empty(); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java new file mode 100644 index 0000000000..b5bf31a6ba --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; + +public abstract class StateModel implements ToXContentObject { + + public abstract String getId(); + + public abstract long getSeqNo(); + + public abstract long getPrimaryTerm(); + + public interface CopyBuilder { + T of(T copy, long seqNo, long primaryTerm); + } + + public interface StateCopyBuilder { + T of(T copy, S state, long seqNo, long primaryTerm); + } + + public interface FromXContent { + T fromXContent(XContentParser parser, long seqNo, long primaryTerm); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java new file mode 100644 index 0000000000..bd72b17353 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionModel; +import 
org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; + +@RequiredArgsConstructor +public class StateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + protected T create(T st, StateModel.CopyBuilder builder) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(st.getId()) + .source(st.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(st.getSeqNo()) + .setIfPrimaryTerm(st.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", st.getId()); + return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. id: %s, error: %s", + st.getId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Optional get(String sid, StateModel.FromXContent builder) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected T updateState( + T st, S state, StateModel.StateCopyBuilder builder) { + try { + T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); + UpdateRequest updateRequest = + new UpdateRequest() + .index(indexName) + .id(model.getId()) + .setIfSeqNo(model.getSeqNo()) + .setIfPrimaryTerm(model.getPrimaryTerm()) + .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .fetchSource(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { + LOG.debug("Successfully update doc. id: {}", st.getId()); + return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed update doc. 
id: %s, error: %s",
+                st.getId(),
+                updateResponse.getResult().getLowercase()));
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /** Helper Functions */
+  public static Function<StatementModel, StatementModel> createStatement(StateStore stateStore) {
+    return (st) -> stateStore.create(st, StatementModel::copy);
+  }
+
+  public static Function<String, Optional<StatementModel>> getStatement(StateStore stateStore) {
+    return (docId) -> stateStore.get(docId, StatementModel::fromXContent);
+  }
+
+  public static BiFunction<StatementModel, StatementState, StatementModel> updateStatementState(
+      StateStore stateStore) {
+    return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState);
+  }
+
+  public static Function<SessionModel, SessionModel> createSession(StateStore stateStore) {
+    return (session) -> stateStore.create(session, SessionModel::of);
+  }
+
+  public static Function<String, Optional<SessionModel>> getSession(StateStore stateStore) {
+    return (docId) -> stateStore.get(docId, SessionModel::fromXContent);
+  }
+
+  public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState(
+      StateStore stateStore) {
+    return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState);
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
index 53dc211ded..488252d05a 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
@@ -7,6 +7,7 @@
 
 import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession;
 import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
 
 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
 import com.amazonaws.services.emrserverless.model.GetJobRunResult;
@@ -20,7 +21,7 @@
 import org.opensearch.action.delete.DeleteRequest;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
-import org.opensearch.sql.spark.execution.statestore.SessionStateStore;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.test.OpenSearchSingleNodeTestCase;
 
 /** mock-maker-inline does not work with OpenSearchTestCase.
*/ @@ -30,13 +31,13 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; - private SessionStateStore stateStore; + private StateStore stateStore; @Before public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new SessionStateStore(indexName, client()); + stateStore = new StateStore(indexName, client()); createIndex(indexName); } @@ -50,7 +51,7 @@ public void openCloseSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(SessionId.newSessionId()) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); @@ -74,7 +75,7 @@ public void openSessionFailedConflict() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); session.open(new CreateSessionRequest(startJobRequest, "datasource")); @@ -82,7 +83,7 @@ public void openSessionFailedConflict() { InteractiveSession duplicateSession = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); IllegalStateException exception = @@ -98,15 +99,15 @@ public void closeNotExistSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); session.open(new CreateSessionRequest(startJobRequest, "datasource")); - client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); - assertEquals("session not exist. " + sessionId, exception.getMessage()); + assertEquals("session does not exist. 
" + sessionId, exception.getMessage()); emrsClient.cancelJobRunCalled(0); } @@ -142,9 +143,9 @@ public void sessionManagerGetSessionNotExist() { @RequiredArgsConstructor static class TestSession { private final Session session; - private final SessionStateStore stateStore; + private final StateStore stateStore; - public static TestSession testSession(Session session, SessionStateStore stateStore) { + public static TestSession testSession(Session session, StateStore stateStore) { return new TestSession(session, stateStore); } @@ -152,7 +153,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - stateStore.get(session.getSessionModel().getSessionId()); + getSession(stateStore).apply(session.getSessionModel().getId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -180,7 +181,7 @@ public TestSession close() { } } - static class TestEMRServerlessClient implements EMRServerlessClient { + public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; private int cancelJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index d35105f787..95b85613be 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -5,29 +5,20 @@ package org.opensearch.sql.spark.execution.session; -import static org.junit.jupiter.api.Assertions.*; - import org.junit.After; import org.junit.Before; -import org.mockito.MockMakers; -import org.mockito.MockSettings; -import org.mockito.Mockito; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; class SessionManagerTest extends OpenSearchSingleNodeTestCase { private static final String indexName = "mockindex"; - // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. 
- private static final MockSettings mockSettings = - Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); - - private SessionStateStore stateStore; + private StateStore stateStore; @Before public void setup() { - stateStore = new SessionStateStore(indexName, client()); + stateStore = new StateStore(indexName, client()); createIndex(indexName); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java new file mode 100644 index 0000000000..b7af1123ba --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class StatementStateTest { + @Test + public void invalidStatementState() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> StatementState.fromString("invalid")); + Assertions.assertEquals("Invalid statement state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java new file mode 100644 index 0000000000..331955e14e --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -0,0 +1,356 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; +import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +public class StatementTest extends OpenSearchSingleNodeTestCase { + + private static final String indexName = "mockindex"; + + private StartJobRequest startJobRequest; + private StateStore stateStore; + private InteractiveSessionTest.TestEMRServerlessClient emrsClient = + new 
InteractiveSessionTest.TestEMRServerlessClient(); + + @Before + public void setup() { + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new StateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + + @Test + public void openThenCancelStatement() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + + // submit statement + TestStatement testStatement = testStatement(st, stateStore); + testStatement + .open() + .assertSessionState(WAITING) + .assertStatementId(new StatementId("statementId")); + + // close statement + testStatement.cancel().assertSessionState(CANCELLED); + } + + @Test + public void openFailedBecauseConflict() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + // open statement with same statement id + Statement dupSt = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); + assertEquals("statement already exist. statementId=statementId", exception.getMessage()); + } + + @Test + public void cancelNotExistStatement() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + client().delete(new DeleteRequest(indexName, stId.getId())); + + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("cancel statement failed. no statement found. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelFailedBecauseOfConflict() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + StatementModel running = + updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED); + + assertEquals(StatementState.CANCELLED, running.getStatementState()); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format( + "cancel statement failed. 
current statementState: CANCELLED " + "statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelRunningStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + // update to running state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), + StatementState.RUNNING, + model.getSeqNo(), + model.getPrimaryTerm())); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in waiting state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void submitStatementInRunningSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void submitStatementInNotStartedState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void failToSubmitStatementInDeadState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " dead", + exception.getMessage()); + } + + @Test + public void failToSubmitStatementInFailState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " fail", + exception.getMessage()); + } + + @Test + public void newStatementFieldAssert() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + Optional statement = session.get(statementId); + + assertTrue(statement.isPresent()); + assertEquals(session.getSessionId(), statement.get().getSessionId()); + assertEquals("appId", statement.get().getApplicationId()); + assertEquals("jobId", statement.get().getJobId()); + assertEquals(statementId, 
statement.get().getStatementId()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(LangType.SQL, statement.get().getLangType()); + assertEquals("select 1", statement.get().getQuery()); + } + + @Test + public void failToSubmitStatementInDeletedSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + // other's delete session + client() + .delete(new DeleteRequest(indexName, session.getSessionId().getSessionId())) + .actionGet(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); + } + + @Test + public void getStatementSuccess() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + + Optional statement = session.get(statementId); + assertTrue(statement.isPresent()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(statementId, statement.get().getStatementId()); + } + + @Test + public void getStatementNotExist() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + + Optional statement = session.get(StatementId.newStatementId()); + assertFalse(statement.isPresent()); + } + + @RequiredArgsConstructor + static class TestStatement { + private final Statement st; + private final StateStore stateStore; + + public static TestStatement testStatement(Statement st, StateStore stateStore) { + return new TestStatement(st, stateStore); + } + + public TestStatement assertSessionState(StatementState expected) { + assertEquals(expected, st.getStatementModel().getStatementState()); + + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementState()); + + return this; + } + + public TestStatement assertStatementId(StatementId expected) { + assertEquals(expected, st.getStatementModel().getStatementId()); + + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementId()); + return this; + } + + public TestStatement open() { + st.open(); + return this; + } + + public TestStatement cancel() { + st.cancel(); + return this; + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java deleted file mode 100644 index 9c779555d7..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import static org.junit.Assert.assertThrows; -import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import 
static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.when; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.client.Client; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionModel; - -@ExtendWith(MockitoExtension.class) -class SessionStateStoreTest { - @Mock(answer = RETURNS_DEEP_STUBS) - private Client client; - - @Mock private IndexResponse indexResponse; - - @Test - public void createWithException() { - when(client.index(any()).actionGet()).thenReturn(indexResponse); - doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); - SessionModel sessionModel = - SessionModel.initInteractiveSession( - "appId", "jobId", SessionId.newSessionId(), "datasource"); - SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); - - assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); - } -} From 8f5e01d47d344923b2d236ef9acaab46e036303b Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 18 Oct 2023 10:01:11 -0700 Subject: [PATCH 07/16] Add sessionId parameters for create async query API (#2312) * add InteractiveSession and SessionManager Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * fix format Signed-off-by: Peng Huo * snapshot Signed-off-by: Peng Huo * address comments Signed-off-by: Peng Huo * update Signed-off-by: Peng Huo * Update REST and Transport interface Signed-off-by: Peng Huo * Revert on transport layer Signed-off-by: Peng Huo * format code Signed-off-by: Peng Huo * add API doc Signed-off-by: Peng Huo * modify api Signed-off-by: Peng Huo * address comments Signed-off-by: Peng Huo * update doc Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- .../sql/common/setting/Settings.java | 10 +- docs/user/admin/settings.rst | 36 +++ docs/user/interfaces/asyncqueryinterface.rst | 44 ++++ .../setting/OpenSearchSettings.java | 14 + .../org/opensearch/sql/plugin/SQLPlugin.java | 9 +- .../AsyncQueryExecutorServiceImpl.java | 19 +- .../model/AsyncQueryExecutionResponse.java | 1 + .../model/AsyncQueryJobMetadata.java | 11 +- .../spark/data/constants/SparkConstants.java | 1 + .../dispatcher/SparkQueryDispatcher.java | 105 ++++++-- .../model/DispatchQueryRequest.java | 3 + .../model/DispatchQueryResponse.java | 1 + .../spark/execution/session/SessionId.java | 2 +- .../execution/session/SessionManager.java | 7 + .../execution/statement/StatementId.java | 2 +- .../rest/model/CreateAsyncQueryRequest.java | 15 +- .../rest/model/CreateAsyncQueryResponse.java | 2 + .../AsyncQueryExecutorServiceImplTest.java | 4 +- .../sql/spark/constants/TestConstants.java | 2 + .../dispatcher/SparkQueryDispatcherTest.java | 244 +++++++++++++++++- .../session/InteractiveSessionTest.java | 9 +- .../execution/session/SessionManagerTest.java | 51 +++- .../execution/statement/StatementTest.java | 17 +- .../model/CreateAsyncQueryRequestTest.java | 52 ++++ ...portCreateAsyncQueryRequestActionTest.java | 21 +- ...ransportGetAsyncQueryResultActionTest.java | 3 +- 26 files changed, 625 insertions(+), 60 deletions(-) create mode 100644 spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java diff 
--git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
index 8daf0e9bf6..89d046b3d9 100644
--- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
+++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
@@ -5,6 +5,8 @@
 package org.opensearch.sql.common.setting;
 
+import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED;
+
 import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
 import java.util.List;
@@ -36,7 +38,8 @@ public enum Key {
   METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"),
   METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"),
   SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"),
-  CLUSTER_NAME("cluster.name");
+  CLUSTER_NAME("cluster.name"),
+  SPARK_EXECUTION_SESSION_ENABLED("plugins.query.executionengine.spark.session.enabled");
 
   @Getter private final String keyValue;
 
@@ -60,4 +63,9 @@ public static Optional<Key> of(String keyValue) {
   public abstract <T> T getSettingValue(Key key);
 
   public abstract List<?> getSettings();
+
+  /** Helper method that reads the session-enabled flag from the given settings. */
+  public static boolean isSparkExecutionSessionEnabled(Settings settings) {
+    return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED);
+  }
 }
diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst
index b5da4e28e2..cd56e76491 100644
--- a/docs/user/admin/settings.rst
+++ b/docs/user/admin/settings.rst
@@ -311,3 +311,39 @@ SQL query::
     "status": 400
   }
 
+plugins.query.executionengine.spark.session.enabled
+===================================================
+
+Description
+-----------
+
+By default, the execution engine runs queries in job mode. You can enable session mode with this setting.
+
+1. The default value is false.
+2. This setting is node scope.
+3. This setting can be updated dynamically.
+
+You can update the setting with a new value like this.
+
+SQL query::
+
+  sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \
+  ... -d '{"transient":{"plugins.query.executionengine.spark.session.enabled":"true"}}'
+  {
+    "acknowledged": true,
+    "persistent": {},
+    "transient": {
+      "plugins": {
+        "query": {
+          "executionengine": {
+            "spark": {
+              "session": {
+                "enabled": "true"
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst
index a9fc77264c..3fbc16d15f 100644
--- a/docs/user/interfaces/asyncqueryinterface.rst
+++ b/docs/user/interfaces/asyncqueryinterface.rst
@@ -62,6 +62,50 @@ Sample Response::
     "queryId": "00fd796ut1a7eg0q"
   }
 
+Execute query in session
+------------------------
+
+If ``plugins.query.executionengine.spark.session.enabled`` is set to true, session-based execution is enabled. Under the hood, all queries submitted to the same session are executed in the same SparkContext. A session is closed automatically if no query is submitted to it for 10 minutes.
+
+When a query runs in a session, the async query response includes a ``sessionId`` field.
+
+Sample Request::
+
+  curl --location 'http://localhost:9200/_plugins/_async_query' \
+  --header 'Content-Type: application/json' \
+  --data '{
+    "datasource" : "my_glue",
+    "lang" : "sql",
+    "query" : "select * from my_glue.default.http_logs limit 10"
+  }'
+
+Sample Response::
+
+  {
+    "queryId": "HlbM61kX6MDkAktO",
+    "sessionId": "1Giy65ZnzNlmsPAm"
+  }
+
+Users can reuse the session by sending the ``sessionId`` field in subsequent requests, as shown below.
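+
+A minimal end-to-end sketch of the flow (illustrative only; it assumes a local
+cluster, the ``my_glue`` datasource from the samples above, and that ``jq`` is
+installed)::
+
+  # Submit a first query without a sessionId; the plugin creates a session.
+  resp=$(curl -s --location 'http://localhost:9200/_plugins/_async_query' \
+    --header 'Content-Type: application/json' \
+    --data '{"datasource" : "my_glue", "lang" : "sql", "query" : "select 1"}')
+
+  # Capture the returned sessionId and send it with the next query so that it
+  # runs in the same SparkContext.
+  session_id=$(echo "$resp" | jq -r '.sessionId')
+  curl -s --location 'http://localhost:9200/_plugins/_async_query' \
+    --header 'Content-Type: application/json' \
+    --data "{\"datasource\" : \"my_glue\", \"lang\" : \"sql\", \"query\" : \"select 2\", \"sessionId\" : \"$session_id\"}"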
+ +Sample Request:: + + curl --location 'http://localhost:9200/_plugins/_async_query' \ + --header 'Content-Type: application/json' \ + --data '{ + "datasource" : "my_glue", + "lang" : "sql", + "query" : "select * from my_glue.default.http_logs limit 10", + "sessionId" : "1Giy65ZnzNlmsPAm" + }' + +Sample Response:: + + { + "queryId": "7GC4mHhftiTejvxN", + "sessionId": "1Giy65ZnzNlmsPAm" + } + Async Query Result API ====================================== diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 76bda07607..ecb35afafa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -135,6 +135,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting SPARK_EXECUTION_SESSION_ENABLED_SETTING = + Setting.boolSetting( + Key.SPARK_EXECUTION_SESSION_ENABLED.getKeyValue(), + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -205,6 +212,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.SPARK_EXECUTION_ENGINE_CONFIG, SPARK_EXECUTION_ENGINE_CONFIG, new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG)); + register( + settingBuilder, + clusterSettings, + Key.SPARK_EXECUTION_SESSION_ENABLED, + SPARK_EXECUTION_SESSION_ENABLED_SETTING, + new Updater(Key.SPARK_EXECUTION_SESSION_ENABLED)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -270,6 +283,7 @@ public static List> pluginSettings() { .add(METRICS_ROLLING_INTERVAL_SETTING) .add(DATASOURCE_URI_HOSTS_DENY_LIST) .add(SPARK_EXECUTION_ENGINE_CONFIG) + .add(SPARK_EXECUTION_SESSION_ENABLED_SETTING) .build(); } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f3fd043b63..a9a35f6318 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.emrserverless.AWSEMRServerless; @@ -99,6 +100,8 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; @@ -318,7 +321,11 @@ private AsyncQueryExecutorService 
createAsyncQueryExecutorService( new DataSourceUserAuthorizationHelperImpl(client), jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), - client); + client, + new SessionManager( + new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), + emrServerlessClient, + pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 13db103f4b..7cba2757cc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -65,14 +65,17 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), sparkExecutionEngineConfig.getClusterName(), - sparkExecutionEngineConfig.getSparkSubmitParameters())); + sparkExecutionEngineConfig.getSparkSubmitParameters(), + createAsyncQueryRequest.getSessionId())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( sparkExecutionEngineConfig.getApplicationId(), dispatchQueryResponse.getJobId(), dispatchQueryResponse.isDropIndexQuery(), - dispatchQueryResponse.getResultIndex())); - return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId()); + dispatchQueryResponse.getResultIndex(), + dispatchQueryResponse.getSessionId())); + return new CreateAsyncQueryResponse( + dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); } @Override @@ -81,6 +84,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { Optional jobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { + String sessionId = jobMetadata.get().getSessionId(); JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get()); if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = @@ -90,13 +94,18 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { result.add(sparkSqlFunctionResponseHandle.next()); } return new AsyncQueryExecutionResponse( - JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result, null); + JobRunState.SUCCESS.toString(), + sparkSqlFunctionResponseHandle.schema(), + result, + null, + sessionId); } else { return new AsyncQueryExecutionResponse( jsonObject.optString(STATUS_FIELD, JobRunState.FAILED.toString()), null, null, - jsonObject.optString(ERROR_FIELD, "")); + jsonObject.optString(ERROR_FIELD, ""), + sessionId); } } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java index d2e54af004..e5d9cffd5f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -19,4 +19,5 @@ public class AsyncQueryExecutionResponse { private final ExecutionEngine.Schema schema; private final List results; private final String error; + private final String 
sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index b470ef989f..b80fefa173 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -30,12 +30,15 @@ public class AsyncQueryJobMetadata { private String jobId; private boolean isDropIndexQuery; private String resultIndex; + // optional sessionId. + private String sessionId; public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) { this.applicationId = applicationId; this.jobId = jobId; this.isDropIndexQuery = false; this.resultIndex = resultIndex; + this.sessionId = null; } @Override @@ -57,6 +60,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) builder.field("applicationId", metadata.getApplicationId()); builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); builder.field("resultIndex", metadata.getResultIndex()); + builder.field("sessionId", metadata.getSessionId()); builder.endObject(); return builder; } @@ -92,6 +96,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws String applicationId = null; boolean isDropIndexQuery = false; String resultIndex = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -109,6 +114,9 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "resultIndex": resultIndex = parser.textOrNull(); break; + case "sessionId": + sessionId = parser.textOrNull(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -116,6 +124,7 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws if (jobId == null || applicationId == null) { throw new IllegalArgumentException("jobId and applicationId are required fields."); } - return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery, resultIndex); + return new AsyncQueryJobMetadata( + applicationId, jobId, isDropIndexQuery, resultIndex, sessionId); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 284afcc0a9..1b248eb15d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -21,6 +21,7 @@ public class SparkConstants { public static final String SPARK_SQL_APPLICATION_JAR = "file:///home/hadoop/.ivy2/jars/org.opensearch_opensearch-spark-sql-application_2.12-0.1.0-SNAPSHOT.jar"; public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; + public static final String SPARK_REQUEST_BUFFER_INDEX_NAME = ".query_execution_request"; // TODO should be replaced with mvn jar. 
public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 347e154885..8d5ae10e91 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -16,6 +16,7 @@ import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import lombok.AllArgsConstructor; import lombok.Getter; @@ -39,6 +40,14 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -69,6 +78,8 @@ public class SparkQueryDispatcher { private Client client; + private SessionManager sessionManager; + public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { return handleSQLQuery(dispatchQueryRequest); @@ -111,23 +122,60 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) String error = items.optString(ERROR_FIELD, ""); result.put(ERROR_FIELD, error); } else { - // make call to EMR Serverless when related result index documents are not available - GetJobRunResult getJobRunResult = - emrServerlessClient.getJobRunResult( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - String jobState = getJobRunResult.getJobRun().getState(); - result.put(STATUS_FIELD, jobState); - result.put(ERROR_FIELD, ""); + if (asyncQueryJobMetadata.getSessionId() != null) { + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + StatementState statementState = statement.get().getStatementState(); + result.put(STATUS_FIELD, statementState.getState()); + result.put(ERROR_FIELD, ""); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. 
" + sessionId); + } + } else { + // make call to EMR Serverless when related result index documents are not available + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); + } } return result; } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - CancelJobRunResult cancelJobRunResult = - emrServerlessClient.cancelJobRun( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - return cancelJobRunResult.getJobRunId(); + if (asyncQueryJobMetadata.getSessionId() != null) { + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + statement.get().cancel(); + return statementId.getId(); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. " + sessionId); + } + } else { + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + return cancelJobRunResult.getJobRunId(); + } } private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) { @@ -173,7 +221,7 @@ private DispatchQueryResponse handleIndexQuery( indexDetails.getAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { @@ -198,8 +246,35 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ tags, false, dataSourceMetadata.getResultIndex()); - String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex()); + if (sessionManager.isEnabled()) { + Session session; + if (dispatchQueryRequest.getSessionId() != null) { + // get session from request + SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); + Optional createdSession = sessionManager.getSession(sessionId); + if (createdSession.isEmpty()) { + throw new IllegalArgumentException("no session found. 
" + sessionId); + } + session = createdSession.get(); + } else { + // create session if not exist + session = + sessionManager.createSession( + new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); + } + StatementId statementId = + session.submit( + new QueryRequest( + dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); + return new DispatchQueryResponse( + statementId.getId(), + false, + dataSourceMetadata.getResultIndex(), + session.getSessionId().getSessionId()); + } else { + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); + } } private DispatchQueryResponse handleDropIndexQuery( @@ -229,7 +304,7 @@ private DispatchQueryResponse handleDropIndexQuery( } } return new DispatchQueryResponse( - new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex()); + new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex(), null); } private static Map getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 823a4570ce..6aa28227a1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -23,4 +23,7 @@ public class DispatchQueryRequest { /** Optional extra Spark submit parameters to include in final request */ private String extraSparkSubmitParams; + + /** Optional sessionId. */ + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 9ee5f156f2..893446c617 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -9,4 +9,5 @@ public class DispatchQueryResponse { private String jobId; private boolean isDropIndexQuery; private String resultIndex; + private String sessionId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index a2847cde18..861d906b9b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -13,7 +13,7 @@ public class SessionId { private final String sessionId; public static SessionId newSessionId() { - return new SessionId(RandomStringUtils.random(10, true, true)); + return new SessionId(RandomStringUtils.randomAlphanumeric(16)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 217af80caf..c34be7015f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -5,10 +5,12 @@ package org.opensearch.sql.spark.execution.session; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED; import static 
org.opensearch.sql.spark.execution.session.SessionId.newSessionId; import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -21,6 +23,7 @@ public class SessionManager { private final StateStore stateStore; private final EMRServerlessClient emrServerlessClient; + private final Settings settings; public Session createSession(CreateSessionRequest request) { InteractiveSession session = @@ -47,4 +50,8 @@ public Optional getSession(SessionId sid) { } return Optional.empty(); } + + public boolean isEnabled() { + return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java index 4baff71493..d9381ad45f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -13,7 +13,7 @@ public class StatementId { private final String id; public static StatementId newStatementId() { - return new StatementId(RandomStringUtils.random(10, true, true)); + return new StatementId(RandomStringUtils.randomAlphanumeric(16)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 8802630d9f..6acf6bc9a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.rest.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_ID; import java.io.IOException; import lombok.Data; @@ -18,6 +19,8 @@ public class CreateAsyncQueryRequest { private String query; private String datasource; private LangType lang; + // optional sessionId + private String sessionId; public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.query = Validate.notNull(query, "Query can't be null"); @@ -25,11 +28,19 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang) { this.lang = Validate.notNull(lang, "lang can't be null"); } + public CreateAsyncQueryRequest(String query, String datasource, LangType lang, String sessionId) { + this.query = Validate.notNull(query, "Query can't be null"); + this.datasource = Validate.notNull(datasource, "Datasource can't be null"); + this.lang = Validate.notNull(lang, "lang can't be null"); + this.sessionId = sessionId; + } + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; LangType lang = null; String datasource = null; + String sessionId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -41,10 +52,12 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) lang = LangType.fromString(langString); } else if (fieldName.equals("datasource")) { datasource = parser.textOrNull(); + } else if 
(fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - return new CreateAsyncQueryRequest(query, datasource, lang); + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java index 8cfe57c2a6..2f918308c4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java @@ -12,4 +12,6 @@ @AllArgsConstructor public class CreateAsyncQueryResponse { private String queryId; + // optional sessionId + private String sessionId; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 01bccd9030..0d4e280b61 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -78,7 +78,7 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) @@ -107,7 +107,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { "--conf spark.dynamicAllocation.enabled=false", TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index abae0377a2..3a0d8fc56d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -16,4 +16,6 @@ public class TestConstants { public static final String EMRS_JOB_NAME = "job_name"; public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER"; + public static final String MOCK_SESSION_ID = "s-0123456"; + public static final String MOCK_STATEMENT_ID = "st-0123456"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 8c0ecb2ea2..58fe626dae 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -8,8 +8,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import 
static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -18,6 +22,8 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_STATEMENT_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; @@ -34,6 +40,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; @@ -58,6 +65,12 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.flint.FlintIndexType; @@ -78,6 +91,12 @@ public class SparkQueryDispatcherTest { @Mock private FlintIndexMetadata flintIndexMetadata; + @Mock private SessionManager sessionManager; + + @Mock private Session session; + + @Mock private Statement statement; + private SparkQueryDispatcher sparkQueryDispatcher; @Captor ArgumentCaptor startJobRequestArgumentCaptor; @@ -91,7 +110,8 @@ void setUp() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); } @Test @@ -256,6 +276,84 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { verifyNoInteractions(flintIndexMetadataReader); } + @Test + void testDispatchSelectQueryCreateNewSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(session).when(sessionManager).createSession(any()); + doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, 
never()).getSession(any()); + Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); + } + + @Test + void testDispatchSelectQueryReuseSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, MOCK_SESSION_ID); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.of(session)) + .when(sessionManager) + .getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); + } + + @Test + void testDispatchSelectQueryInvalidSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid"); + + doReturn(true).when(sessionManager).isEnabled(); + doReturn(Optional.empty()).when(sessionManager).getSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + verify(sessionManager, never()).createSession(any()); + Assertions.assertEquals( + "no session found. 
" + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testDispatchSelectQueryFailedCreateSession() { + String query = "select * from my_glue.default.http_logs"; + DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); + + doReturn(true).when(sessionManager).isEnabled(); + doThrow(RuntimeException.class).when(sessionManager).createSession(any()); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + Assertions.assertThrows( + RuntimeException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + + verifyNoInteractions(emrServerlessClient); + } + @Test void testDispatchIndexQuery() { HashMap tags = new HashMap<>(); @@ -544,6 +642,68 @@ void testCancelJob() { Assertions.assertEquals(EMR_JOB_ID, jobId); } + @Test + void testCancelQueryWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doNothing().when(statement).cancel(); + + String queryId = + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + verify(statement, times(1)).cancel(); + Assertions.assertEquals(MOCK_STATEMENT_ID, queryId); + } + + @Test + void testCancelQueryWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(new SessionId("invalid")); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(session); + Assertions.assertEquals( + "no session found. " + new SessionId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithInvalidStatementId() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(statement); + Assertions.assertEquals( + "no statement found. 
" + new StatementId("invalid"), exception.getMessage()); + } + + @Test + void testCancelQueryWithNoSessionId() { + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String jobId = + sparkQueryDispatcher.cancelJob( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + @Test void testGetQueryResponse() { when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) @@ -558,6 +718,60 @@ void testGetQueryResponse() { Assertions.assertEquals("PENDING", result.get("status")); } + @Test + void testGetQueryResponseWithSession() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.of(statement)).when(session).get(any()); + doReturn(StatementState.WAITING).when(statement).getStatementState(); + + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + JSONObject result = + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals("waiting", result.get("status")); + } + + @Test + void testGetQueryResponseWithInvalidSession() { + doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no session found. " + new SessionId(MOCK_SESSION_ID), exception.getMessage()); + } + + @Test + void testGetQueryResponseWithStatementNotExist() { + doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); + doReturn(Optional.empty()).when(session).get(any()); + doReturn(new JSONObject()) + .when(jobExecutionResponseReader) + .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.getQueryResponse( + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + verifyNoInteractions(emrServerlessClient); + Assertions.assertEquals( + "no statement found. 
" + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); + } + @Test void testGetQueryResponseWithSuccess() { SparkQueryDispatcher sparkQueryDispatcher = @@ -567,7 +781,8 @@ void testGetQueryResponseWithSuccess() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); JSONObject queryResult = new JSONObject(); Map resultMap = new HashMap<>(); resultMap.put(STATUS_FIELD, "SUCCESS"); @@ -604,14 +819,15 @@ void testGetQueryResponseOfDropIndex() { dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader, - openSearchClient); + openSearchClient, + sessionManager); String jobId = new SparkQueryDispatcher.DropIndexResult(JobRunState.SUCCESS.toString()).toJobId(); JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null)); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null, null)); verify(jobExecutionResponseReader, times(0)) .getResultFromOpensearchIndex(anyString(), anyString()); Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); @@ -978,6 +1194,24 @@ private DispatchQueryRequest constructDispatchQueryRequest( langType, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME, - extraParameters); + extraParameters, + null); + } + + private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, String sessionId) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + null, + sessionId); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( + String queryId, String sessionId) { + return new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, queryId, false, null, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 488252d05a..429c970365 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; @@ -114,7 +115,7 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); TestSession testSession = testSession(session, stateStore); @@ -123,7 +124,8 @@ public void sessionManagerCreateSession() { @Test public void sessionManagerGetSession() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + SessionManager sessionManager = + new SessionManager(stateStore, emrsClient, sessionSetting(false)); Session session = sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); @@ -134,7 +136,8 @@ public void sessionManagerGetSession() { @Test public void 
sessionManagerGetSessionNotExist() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + SessionManager sessionManager = + new SessionManager(stateStore, emrsClient, sessionSetting(false)); Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); assertTrue(managerSession.isEmpty()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index 95b85613be..4374bd4f11 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -5,25 +5,48 @@ package org.opensearch.sql.spark.execution.session; -import org.junit.After; -import org.junit.Before; -import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchSingleNodeTestCase; -class SessionManagerTest extends OpenSearchSingleNodeTestCase { - private static final String indexName = "mockindex"; +@ExtendWith(MockitoExtension.class) +public class SessionManagerTest { + @Mock private StateStore stateStore; + @Mock private EMRServerlessClient emrClient; - private StateStore stateStore; + @Test + public void sessionEnable() { + Assertions.assertTrue( + new SessionManager(stateStore, emrClient, sessionSetting(true)).isEnabled()); + Assertions.assertFalse( + new SessionManager(stateStore, emrClient, sessionSetting(false)).isEnabled()); + } - @Before - public void setup() { - stateStore = new StateStore(indexName, client()); - createIndex(indexName); + public static Settings sessionSetting(boolean enabled) { + Map settings = new HashMap<>(); + settings.put(Settings.Key.SPARK_EXECUTION_SESSION_ENABLED, enabled); + return settings(settings); } - @After - public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + public static Settings settings(Map settings) { + return new Settings() { + @Override + public T getSettingValue(Key key) { + return (T) settings.get(key); + } + + @Override + public List getSettings() { + return (List) settings; + } + }; } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 331955e14e..214bcb8258 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; @@ -196,7 +197,7 @@ public 
void cancelRunningStatementFailed() { @Test public void submitStatementInRunningSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running @@ -209,7 +210,7 @@ public void submitStatementInRunningSession() { @Test public void submitStatementInNotStartedState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); @@ -219,7 +220,7 @@ public void submitStatementInNotStartedState() { @Test public void failToSubmitStatementInDeadState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); @@ -237,7 +238,7 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); @@ -255,7 +256,7 @@ public void failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -273,7 +274,7 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // other's delete session @@ -291,7 +292,7 @@ public void failToSubmitStatementInDeletedSession() { @Test public void getStatementSuccess() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); @@ -306,7 +307,7 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { Session session = - new SessionManager(stateStore, emrsClient) + new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); // App change state to running updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); diff --git a/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java new file mode 100644 index 0000000000..dd634d6055 --- /dev/null +++ 
b/spark/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import java.io.IOException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; + +public class CreateAsyncQueryRequestTest { + + @Test + public void fromXContent() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("my_glue", queryRequest.getDatasource()); + Assertions.assertEquals(LangType.SQL, queryRequest.getLang()); + Assertions.assertEquals("select 1", queryRequest.getQuery()); + } + + @Test + public void fromXContentWithSessionId() throws IOException { + String request = + "{\n" + + " \"datasource\": \"my_glue\",\n" + + " \"lang\": \"sql\",\n" + + " \"query\": \"select 1\",\n" + + " \"sessionId\": \"00fdjevgkf12s00q\"\n" + + "}"; + CreateAsyncQueryRequest queryRequest = + CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + Assertions.assertEquals("00fdjevgkf12s00q", queryRequest.getSessionId()); + } + + private XContentParser xContentParser(String request) throws IOException { + return XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, request); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 8599e4b88e..36060d3850 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.MOCK_SESSION_ID; import java.util.HashSet; import org.junit.jupiter.api.Assertions; @@ -61,7 +62,7 @@ public void testDoExecute() { CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) - .thenReturn(new CreateAsyncQueryResponse("123")); + .thenReturn(new CreateAsyncQueryResponse("123", null)); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); CreateAsyncQueryActionResponse createAsyncQueryActionResponse = @@ -70,6 +71,24 @@ public void testDoExecute() { "{\n" + " \"queryId\": \"123\"\n" + "}", createAsyncQueryActionResponse.getResult()); } + @Test + public void testDoExecuteWithSessionId() { + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest( + "source = my_glue.default.alb_logs", "my_glue", LangType.SQL, MOCK_SESSION_ID); + CreateAsyncQueryActionRequest request = + new CreateAsyncQueryActionRequest(createAsyncQueryRequest); 
+ when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) + .thenReturn(new CreateAsyncQueryResponse("123", MOCK_SESSION_ID)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + CreateAsyncQueryActionResponse createAsyncQueryActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + " \"queryId\": \"123\",\n" + " \"sessionId\": \"s-0123456\"\n" + "}", + createAsyncQueryActionResponse.getResult()); + } + @Test public void testDoExecuteWithException() { CreateAsyncQueryRequest createAsyncQueryRequest = diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 21a213c7c2..34f10b0083 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -63,7 +63,7 @@ public void setUp() { public void testDoExecute() { GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null); + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); @@ -89,6 +89,7 @@ public void testDoExecuteWithSuccessResponse() { Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Smith", "age", 30))), + null, null); when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); action.doExecute(task, request, actionListener); From f835112bc1d60b4f814baa0fa68512c27d4343b3 Mon Sep 17 00:00:00 2001 From: Derek Ho Date: Wed, 18 Oct 2023 19:34:43 -0400 Subject: [PATCH 08/16] Implement patch API for datasources (#2273) * Implement patch API for datasources Signed-off-by: Derek Ho * Change patch implementation to Map Signed-off-by: Derek Ho * Fix up, everything complete except unit test Signed-off-by: Derek Ho * Revise PR to use existing functions Signed-off-by: Derek Ho * Remove unused utility function Signed-off-by: Derek Ho * Add tests Signed-off-by: Derek Ho * Add back line Signed-off-by: Derek Ho * fix build issue Signed-off-by: Derek Ho * Fix tests and add in rst Signed-off-by: Derek Ho * Register patch Signed-off-by: Derek Ho * Add imports Signed-off-by: Derek Ho * Patch Signed-off-by: Derek Ho * Fix integration test Signed-off-by: Derek Ho * Update IT Signed-off-by: Derek Ho * Add tests Signed-off-by: Derek Ho * Fix test Signed-off-by: Derek Ho * Fix tests and increase code cov Signed-off-by: Derek Ho * Add more coverage to impl Signed-off-by: Derek Ho * Fix test and jacoco passing Signed-off-by: Derek Ho * Test fix Signed-off-by: Derek Ho * Add docs Signed-off-by: Derek Ho --------- Signed-off-by: Derek Ho --- .../sql/datasource/DataSourceService.java | 10 ++- .../sql/analysis/AnalyzerTestBase.java | 3 + .../PatchDataSourceActionRequest.java | 49 ++++++++++++ .../PatchDataSourceActionResponse.java | 31 ++++++++ .../rest/RestDataSourceQueryAction.java | 61 ++++++++++---- 
.../service/DataSourceServiceImpl.java | 50 ++++++++++-- .../TransportPatchDataSourceAction.java | 74 +++++++++++++++++ .../utils/XContentParserUtils.java | 72 +++++++++++++++++ .../service/DataSourceServiceImplTest.java | 42 +++++++++- .../TransportPatchDataSourceActionTest.java | 79 +++++++++++++++++++ .../utils/XContentParserUtilsTest.java | 74 +++++++++++++++++ docs/user/ppl/admin/datasources.rst | 14 ++++ .../sql/datasource/DataSourceAPIsIT.java | 29 +++++++ .../sql/legacy/SQLIntegTestCase.java | 10 +++ .../org/opensearch/sql/plugin/SQLPlugin.java | 14 ++-- 15 files changed, 580 insertions(+), 32 deletions(-) create mode 100644 datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java create mode 100644 datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java create mode 100644 datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java create mode 100644 datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 6dace50f99..162fe9e8f8 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -5,6 +5,7 @@ package org.opensearch.sql.datasource; +import java.util.Map; import java.util.Set; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; @@ -56,12 +57,19 @@ public interface DataSourceService { void createDataSource(DataSourceMetadata metadata); /** - * Updates {@link DataSource} corresponding to dataSourceMetadata. + * Updates {@link DataSource} corresponding to dataSourceMetadata (all fields must be provided). * * @param dataSourceMetadata {@link DataSourceMetadata}. */ void updateDataSource(DataSourceMetadata dataSourceMetadata); + /** + * Patches {@link DataSource} corresponding to the given name (only the fields to be changed are needed). + * + * @param dataSourceData map of the fields to change, keyed by field name; must include the datasource name. + */ + void patchDataSource(Map<String, Object> dataSourceData); + /** * Deletes {@link DataSource} corresponding to the DataSource name. *
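To make the update-versus-patch contract above concrete, here is a minimal caller-side sketch (illustrative only: the datasource name and description are made-up values, and `service` stands for any `DataSourceService` implementation):

```java
import java.util.HashMap;
import java.util.Map;
import org.opensearch.sql.datasource.DataSourceService;

class PatchVsUpdateSketch {
  // Patch sends only the fields to change, plus the mandatory "name" key.
  static void patchDescription(DataSourceService service) {
    Map<String, Object> patch = new HashMap<>();
    patch.put("name", "my_prometheus");               // identifies an existing datasource
    patch.put("description", "updated description");  // the only field being changed
    service.patchDataSource(patch);                   // unmentioned fields keep their values
  }
}
```

An `updateDataSource` call, by contrast, has to ship a fully populated `DataSourceMetadata`, since the stored document is replaced wholesale.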
diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index 508567582b..569cdd96f8 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -231,6 +231,9 @@ public DataSource getDataSource(String dataSourceName) { @Override public void updateDataSource(DataSourceMetadata dataSourceMetadata) {} + @Override + public void patchDataSource(Map<String, Object> dataSourceData) {} + @Override public void deleteDataSource(String dataSourceName) {} diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java new file mode 100644 index 0000000000..9443ea561e --- /dev/null +++ b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionRequest.java @@ -0,0 +1,49 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.datasources.model.transport; + +import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.CONNECTOR_FIELD; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.NAME_FIELD; + +import java.io.IOException; +import java.util.Map; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; + +public class PatchDataSourceActionRequest extends ActionRequest { + + @Getter private Map<String, Object> dataSourceData; + + /** Constructor of PatchDataSourceActionRequest from StreamInput. 
*/ + public PatchDataSourceActionRequest(StreamInput in) throws IOException { + super(in); + } + + public PatchDataSourceActionRequest(Map<String, Object> dataSourceData) { + this.dataSourceData = dataSourceData; + } + + @Override + public ActionRequestValidationException validate() { + if (this.dataSourceData.get(NAME_FIELD).equals(DEFAULT_DATASOURCE_NAME)) { + ActionRequestValidationException exception = new ActionRequestValidationException(); + exception.addValidationError( + "Not allowed to update datasource with name: " + DEFAULT_DATASOURCE_NAME); + return exception; + } else if (this.dataSourceData.get(CONNECTOR_FIELD) != null) { + ActionRequestValidationException exception = new ActionRequestValidationException(); + exception.addValidationError("Not allowed to update connector for datasource"); + return exception; + } else { + return null; + } + } +} diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java new file mode 100644 index 0000000000..18873a6731 --- /dev/null +++ b/datasources/src/main/java/org/opensearch/sql/datasources/model/transport/PatchDataSourceActionResponse.java @@ -0,0 +1,31 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.datasources.model.transport; + +import java.io.IOException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +@RequiredArgsConstructor +public class PatchDataSourceActionResponse extends ActionResponse { + + @Getter private final String result; + + public PatchDataSourceActionResponse(StreamInput in) throws IOException { + super(in); + result = in.readString(); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeString(result); + } +} diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index 2947afc5b9..c207f55738 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -10,15 +10,13 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.NOT_FOUND; import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; -import static org.opensearch.rest.RestRequest.Method.DELETE; -import static org.opensearch.rest.RestRequest.Method.GET; -import static org.opensearch.rest.RestRequest.Method.POST; -import static org.opensearch.rest.RestRequest.Method.PUT; +import static org.opensearch.rest.RestRequest.Method.*; import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -32,18 +30,8 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.datasources.exceptions.ErrorMessage; -import 
org.opensearch.sql.datasources.model.transport.CreateDataSourceActionRequest; -import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse; -import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionRequest; -import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionResponse; -import org.opensearch.sql.datasources.model.transport.GetDataSourceActionRequest; -import org.opensearch.sql.datasources.model.transport.GetDataSourceActionResponse; -import org.opensearch.sql.datasources.model.transport.UpdateDataSourceActionRequest; -import org.opensearch.sql.datasources.model.transport.UpdateDataSourceActionResponse; -import org.opensearch.sql.datasources.transport.TransportCreateDataSourceAction; -import org.opensearch.sql.datasources.transport.TransportDeleteDataSourceAction; -import org.opensearch.sql.datasources.transport.TransportGetDataSourceAction; -import org.opensearch.sql.datasources.transport.TransportUpdateDataSourceAction; +import org.opensearch.sql.datasources.model.transport.*; +import org.opensearch.sql.datasources.transport.*; import org.opensearch.sql.datasources.utils.Scheduler; import org.opensearch.sql.datasources.utils.XContentParserUtils; @@ -98,6 +86,17 @@ public List routes() { */ new Route(PUT, BASE_DATASOURCE_ACTION_URL), + /* + * PATCH datasources + * Request body: + * Ref + * [org.opensearch.sql.plugin.transport.datasource.model.PatchDataSourceActionRequest] + * Response body: + * Ref + * [org.opensearch.sql.plugin.transport.datasource.model.PatchDataSourceActionResponse] + */ + new Route(PATCH, BASE_DATASOURCE_ACTION_URL), + /* * DELETE datasources * Request body: Ref @@ -122,6 +121,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient return executeUpdateRequest(restRequest, nodeClient); case DELETE: return executeDeleteRequest(restRequest, nodeClient); + case PATCH: + return executePatchRequest(restRequest, nodeClient); default: return restChannel -> restChannel.sendResponse( @@ -216,6 +217,34 @@ public void onFailure(Exception e) { })); } + private RestChannelConsumer executePatchRequest(RestRequest restRequest, NodeClient nodeClient) + throws IOException { + Map dataSourceData = XContentParserUtils.toMap(restRequest.contentParser()); + return restChannel -> + Scheduler.schedule( + nodeClient, + () -> + nodeClient.execute( + TransportPatchDataSourceAction.ACTION_TYPE, + new PatchDataSourceActionRequest(dataSourceData), + new ActionListener<>() { + @Override + public void onResponse( + PatchDataSourceActionResponse patchDataSourceActionResponse) { + restChannel.sendResponse( + new BytesRestResponse( + RestStatus.OK, + "application/json; charset=UTF-8", + patchDataSourceActionResponse.getResult())); + } + + @Override + public void onFailure(Exception e) { + handleException(e, restChannel); + } + })); + } + private RestChannelConsumer executeDeleteRequest(RestRequest restRequest, NodeClient nodeClient) { String dataSourceName = restRequest.param("dataSourceName"); diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 25e8006d66..8ba618fb44 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -6,15 +6,11 @@ package org.opensearch.sql.datasources.service; import static 
org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.*; import com.google.common.base.Preconditions; import com.google.common.base.Strings; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; +import java.util.*; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSource; @@ -100,6 +96,19 @@ public void updateDataSource(DataSourceMetadata dataSourceMetadata) { } } + @Override + public void patchDataSource(Map<String, Object> dataSourceData) { + if (!dataSourceData.get(NAME_FIELD).equals(DEFAULT_DATASOURCE_NAME)) { + DataSourceMetadata dataSourceMetadata = + getRawDataSourceMetadata((String) dataSourceData.get(NAME_FIELD)); + replaceOldDatasourceMetadata(dataSourceData, dataSourceMetadata); + updateDataSource(dataSourceMetadata); + } else { + throw new UnsupportedOperationException( + "Not allowed to update default datasource :" + DEFAULT_DATASOURCE_NAME); + } + } + @Override public void deleteDataSource(String dataSourceName) { if (dataSourceName.equals(DEFAULT_DATASOURCE_NAME)) { @@ -136,6 +145,35 @@ private void validateDataSourceMetaData(DataSourceMetadata metadata) { + " Properties are required parameters."); } + /** + * Replaces the fields of the given metadata with the values in the map. + * + * @param dataSourceData map of the fields to change. + * @param metadata {@link DataSourceMetadata}. + */ + private void replaceOldDatasourceMetadata( + Map<String, Object> dataSourceData, DataSourceMetadata metadata) { + + for (String key : dataSourceData.keySet()) { + switch (key) { + // Name and connector should not be modified + case DESCRIPTION_FIELD: + metadata.setDescription((String) dataSourceData.get(DESCRIPTION_FIELD)); + break; + case ALLOWED_ROLES_FIELD: + metadata.setAllowedRoles((List<String>) dataSourceData.get(ALLOWED_ROLES_FIELD)); + break; + case PROPERTIES_FIELD: + Map<String, String> properties = new HashMap<>(metadata.getProperties()); + properties.putAll(((Map<String, String>) dataSourceData.get(PROPERTIES_FIELD))); + metadata.setProperties(properties); + break; + case NAME_FIELD: + case CONNECTOR_FIELD: + break; + } + } + } + @Override public DataSourceMetadata getRawDataSourceMetadata(String dataSourceName) { if (dataSourceName.equals(DEFAULT_DATASOURCE_NAME)) {
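As a worked illustration of the merge performed by `replaceOldDatasourceMetadata` above (the property values are hypothetical; only `properties` is shown because it merges key by key rather than being replaced outright):

```java
import java.util.HashMap;
import java.util.Map;

class PropertiesMergeSketch {
  public static void main(String[] args) {
    // Properties already stored on the datasource metadata.
    Map<String, String> existing = new HashMap<>(Map.of(
        "prometheus.uri", "http://localhost:9090",
        "prometheus.auth.type", "basicauth"));
    // Properties supplied in the PATCH body.
    Map<String, String> patch = Map.of("prometheus.uri", "http://remote:9090");

    // Same copy-then-putAll merge as in the service implementation:
    // patched keys overwrite, untouched keys survive.
    Map<String, String> merged = new HashMap<>(existing);
    merged.putAll(patch);
    System.out.println(merged); // uri is replaced, auth.type is kept
  }
}
```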
diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java new file mode 100644 index 0000000000..303e905cec --- /dev/null +++ b/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceAction.java @@ -0,0 +1,74 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.datasources.transport; + +import static org.opensearch.sql.datasources.utils.XContentParserUtils.NAME_FIELD; +import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionRequest; +import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionResponse; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportPatchDataSourceAction + extends HandledTransportAction<PatchDataSourceActionRequest, PatchDataSourceActionResponse> { + + public static final String NAME = "cluster:admin/opensearch/ql/datasources/patch"; + public static final ActionType<PatchDataSourceActionResponse> ACTION_TYPE = + new ActionType<>(NAME, PatchDataSourceActionResponse::new); + + private DataSourceService dataSourceService; + + /** + * TransportPatchDataSourceAction action for patching a datasource. + * + * @param transportService transportService. + * @param actionFilters actionFilters. + * @param dataSourceService dataSourceService. + */ + @Inject + public TransportPatchDataSourceAction( + TransportService transportService, + ActionFilters actionFilters, + DataSourceServiceImpl dataSourceService) { + super( + TransportPatchDataSourceAction.NAME, + transportService, + actionFilters, + PatchDataSourceActionRequest::new); + this.dataSourceService = dataSourceService; + } + + @Override + protected void doExecute( + Task task, + PatchDataSourceActionRequest request, + ActionListener<PatchDataSourceActionResponse> actionListener) { + try { + dataSourceService.patchDataSource(request.getDataSourceData()); + String responseContent = + new JsonResponseFormatter<String>(PRETTY) { + @Override + protected Object buildJsonObject(String response) { + return response; + } + }.format("Updated DataSource with name " + request.getDataSourceData().get(NAME_FIELD)); + actionListener.onResponse(new PatchDataSourceActionResponse(responseContent)); + } catch (Exception e) { + actionListener.onFailure(e); + } + } +} diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java index 261f13870a..6af2a5a761 100--- --- a/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/utils/XContentParserUtils.java @@ -90,6 +90,59 @@ public static DataSourceMetadata toDataSourceMetadata(XContentParser parser) thr name, description, connector, allowedRoles, properties, resultIndex); } + public static Map<String, Object> toMap(XContentParser parser) throws IOException { + Map<String, Object> resultMap = new HashMap<>(); + String name; + String description; + List<String> allowedRoles = new ArrayList<>(); + Map<String, String> properties = new HashMap<>(); + String resultIndex; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case NAME_FIELD: + name = parser.textOrNull(); + resultMap.put(NAME_FIELD, name); + break; + case DESCRIPTION_FIELD: + description = parser.textOrNull(); + resultMap.put(DESCRIPTION_FIELD, description); + break; + case CONNECTOR_FIELD: + // no-op - datasource connector should not be modified + break; + case ALLOWED_ROLES_FIELD: + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + allowedRoles.add(parser.text()); + } + resultMap.put(ALLOWED_ROLES_FIELD, allowedRoles); + break; + case PROPERTIES_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String key = 
parser.currentName(); + parser.nextToken(); + String value = parser.textOrNull(); + properties.put(key, value); + } + resultMap.put(PROPERTIES_FIELD, properties); + break; + case RESULT_INDEX_FIELD: + resultIndex = parser.textOrNull(); + resultMap.put(RESULT_INDEX_FIELD, resultIndex); + break; + default: + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + } + if (resultMap.get(NAME_FIELD) == null || "".equals(resultMap.get(NAME_FIELD))) { + throw new IllegalArgumentException("Name is a required field."); + } + return resultMap; + } + /** * Converts json string to DataSourceMetadata. * @@ -109,6 +162,25 @@ public static DataSourceMetadata toDataSourceMetadata(String json) throws IOExce } } + /** + * Converts json string to Map. + * + * @param json json string. + * @return dataSourceData map. + * @throws IOException IOException. + */ + public static Map<String, Object> toMap(String json) throws IOException { + try (XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + json)) { + return toMap(parser); + } + } + /** * Converts DataSourceMetadata to XContentBuilder. * diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index 6164d8b73f..c62e586dae 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -18,6 +18,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.*; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; @@ -264,7 +265,6 @@ void testGetDataSourceMetadataSetWithDefaultDatasource() { @Test void testUpdateDataSourceSuccessCase() { - DataSourceMetadata dataSourceMetadata = metadata("testDS", DataSourceType.OPENSEARCH, Collections.emptyList(), ImmutableMap.of()); dataSourceService.updateDataSource(dataSourceMetadata); @@ -289,6 +289,46 @@ void testUpdateDefaultDataSource() { unsupportedOperationException.getMessage()); } + @Test + void testPatchDefaultDataSource() { + Map<String, Object> dataSourceData = + Map.of(NAME_FIELD, DEFAULT_DATASOURCE_NAME, DESCRIPTION_FIELD, "test"); + UnsupportedOperationException unsupportedOperationException = + assertThrows( + UnsupportedOperationException.class, + () -> dataSourceService.patchDataSource(dataSourceData)); + assertEquals( + "Not allowed to update default datasource :" + DEFAULT_DATASOURCE_NAME, + unsupportedOperationException.getMessage()); + } + + @Test + void testPatchDataSourceSuccessCase() { + // Tests that the underlying implementation of patch is to call update + Map<String, Object> dataSourceData = + new HashMap<>( + Map.of( + NAME_FIELD, + "testDS", + DESCRIPTION_FIELD, + "test", + CONNECTOR_FIELD, + "PROMETHEUS", + ALLOWED_ROLES_FIELD, + new ArrayList<>(), + PROPERTIES_FIELD, + Map.of(), + RESULT_INDEX_FIELD, + "")); + DataSourceMetadata getData = + metadata("testDS", DataSourceType.OPENSEARCH, Collections.emptyList(), ImmutableMap.of()); + when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) + .thenReturn(Optional.ofNullable(getData)); + + dataSourceService.patchDataSource(dataSourceData); + 
verify(dataSourceMetadataStorage, times(1)).updateDataSourceMetadata(any()); + } + @Test void testDeleteDatasource() { dataSourceService.deleteDataSource("testDS"); diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java new file mode 100644 index 0000000000..5e1e7df112 --- /dev/null +++ b/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportPatchDataSourceActionTest.java @@ -0,0 +1,79 @@ +package org.opensearch.sql.datasources.transport; + +import static org.mockito.Mockito.*; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.*; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionRequest; +import org.opensearch.sql.datasources.model.transport.PatchDataSourceActionResponse; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportPatchDataSourceActionTest { + + @Mock private TransportService transportService; + @Mock private TransportPatchDataSourceAction action; + @Mock private DataSourceServiceImpl dataSourceService; + @Mock private Task task; + @Mock private ActionListener actionListener; + + @Captor + private ArgumentCaptor patchDataSourceActionResponseArgumentCaptor; + + @Captor private ArgumentCaptor exceptionArgumentCaptor; + + @BeforeEach + public void setUp() { + action = + new TransportPatchDataSourceAction( + transportService, new ActionFilters(new HashSet<>()), dataSourceService); + } + + @Test + public void testDoExecute() { + Map dataSourceData = new HashMap<>(); + dataSourceData.put(NAME_FIELD, "test_datasource"); + dataSourceData.put(DESCRIPTION_FIELD, "test"); + + PatchDataSourceActionRequest request = new PatchDataSourceActionRequest(dataSourceData); + + action.doExecute(task, request, actionListener); + verify(dataSourceService, times(1)).patchDataSource(dataSourceData); + Mockito.verify(actionListener) + .onResponse(patchDataSourceActionResponseArgumentCaptor.capture()); + PatchDataSourceActionResponse patchDataSourceActionResponse = + patchDataSourceActionResponseArgumentCaptor.getValue(); + String responseAsJson = "\"Updated DataSource with name test_datasource\""; + Assertions.assertEquals(responseAsJson, patchDataSourceActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + Map dataSourceData = new HashMap<>(); + dataSourceData.put(NAME_FIELD, "test_datasource"); + dataSourceData.put(DESCRIPTION_FIELD, "test"); + doThrow(new RuntimeException("Error")).when(dataSourceService).patchDataSource(dataSourceData); + PatchDataSourceActionRequest request = new PatchDataSourceActionRequest(dataSourceData); + action.doExecute(task, request, actionListener); + verify(dataSourceService, times(1)).patchDataSource(dataSourceData); + 
Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("Error", exception.getMessage()); + } +} diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java index d134293456..e1e442d12b 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/utils/XContentParserUtilsTest.java @@ -1,6 +1,7 @@ package org.opensearch.sql.datasources.utils; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.*; import com.google.gson.Gson; import java.util.HashMap; @@ -52,6 +53,46 @@ public void testToDataSourceMetadataFromJson() { Assertions.assertEquals("prometheus_access", retrievedMetadata.getAllowedRoles().get(0)); } + @SneakyThrows + @Test + public void testToMapFromJson() { + Map dataSourceData = + Map.of( + NAME_FIELD, + "test_DS", + DESCRIPTION_FIELD, + "test", + ALLOWED_ROLES_FIELD, + List.of("all_access"), + PROPERTIES_FIELD, + Map.of("prometheus.uri", "localhost:9090"), + CONNECTOR_FIELD, + "PROMETHEUS", + RESULT_INDEX_FIELD, + ""); + + Map dataSourceDataConnectorRemoved = + Map.of( + NAME_FIELD, + "test_DS", + DESCRIPTION_FIELD, + "test", + ALLOWED_ROLES_FIELD, + List.of("all_access"), + PROPERTIES_FIELD, + Map.of("prometheus.uri", "localhost:9090"), + RESULT_INDEX_FIELD, + ""); + + Gson gson = new Gson(); + String json = gson.toJson(dataSourceData); + + Map parsedData = XContentParserUtils.toMap(json); + + Assertions.assertEquals(parsedData, dataSourceDataConnectorRemoved); + Assertions.assertEquals("test", parsedData.get(DESCRIPTION_FIELD)); + } + @SneakyThrows @Test public void testToDataSourceMetadataFromJsonWithoutName() { @@ -71,6 +112,22 @@ public void testToDataSourceMetadataFromJsonWithoutName() { Assertions.assertEquals("name and connector are required fields.", exception.getMessage()); } + @SneakyThrows + @Test + public void testToMapFromJsonWithoutName() { + Map dataSourceData = new HashMap<>(Map.of(DESCRIPTION_FIELD, "test")); + Gson gson = new Gson(); + String json = gson.toJson(dataSourceData); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + XContentParserUtils.toMap(json); + }); + Assertions.assertEquals("Name is a required field.", exception.getMessage()); + } + @SneakyThrows @Test public void testToDataSourceMetadataFromJsonWithoutConnector() { @@ -106,4 +163,21 @@ public void testToDataSourceMetadataFromJsonUsingUnknownObject() { }); Assertions.assertEquals("Unknown field: test", exception.getMessage()); } + + @SneakyThrows + @Test + public void testToMapFromJsonUsingUnknownObject() { + HashMap hashMap = new HashMap<>(); + hashMap.put("test", "test"); + Gson gson = new Gson(); + String json = gson.toJson(hashMap); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + XContentParserUtils.toMap(json); + }); + Assertions.assertEquals("Unknown field: test", exception.getMessage()); + } } diff --git a/docs/user/ppl/admin/datasources.rst b/docs/user/ppl/admin/datasources.rst index 3682153b9d..31378f6cc4 100644 --- a/docs/user/ppl/admin/datasources.rst +++ 
b/docs/user/ppl/admin/datasources.rst @@ -93,6 +93,19 @@ we can remove authorization and other details in case of security disabled domai "allowedRoles" : ["prometheus_access"] } +* Datasource modification PATCH API ("_plugins/_query/_datasources") :: + + PATCH https://localhost:9200/_plugins/_query/_datasources + content-type: application/json + Authorization: Basic {{username}} {{password}} + + { + "name" : "my_prometheus", + "allowedRoles" : ["all_access"] + } + + **Name is required and must exist. Connector cannot be modified and will be ignored.** + * Datasource Read GET API("_plugins/_query/_datasources/{{dataSourceName}}" :: GET https://localhost:9200/_plugins/_query/_datasources/my_prometheus @@ -114,6 +127,7 @@ Each of the datasource configuration management apis are controlled by following * cluster:admin/opensearch/datasources/create [Create POST API] * cluster:admin/opensearch/datasources/read [Get GET API] * cluster:admin/opensearch/datasources/update [Update PUT API] +* cluster:admin/opensearch/datasources/patch [Patch PATCH API] * cluster:admin/opensearch/datasources/delete [Delete DELETE API] Only users mapped with roles having above actions are authorized to execute datasource management apis.
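The PATCH documentation above shows the request but not the acknowledgement. Going by the response string built by TransportPatchDataSourceAction and the assertion in the integration test that follows, a successful call returns a bare JSON string. An illustrative round trip, in the same style as the doc's other examples:

```
PATCH https://localhost:9200/_plugins/_query/_datasources
content-type: application/json

{ "name" : "my_prometheus", "allowedRoles" : ["all_access"] }

HTTP 200
"Updated DataSource with name my_prometheus"
```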
diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 8623b9fa6f..ff36d2a887 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -5,6 +5,8 @@ package org.opensearch.sql.datasource; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.DESCRIPTION_FIELD; +import static org.opensearch.sql.datasources.utils.XContentParserUtils.NAME_FIELD; import static org.opensearch.sql.legacy.TestUtils.getResponseBody; import com.google.common.collect.ImmutableList; @@ -15,7 +17,9 @@ import java.io.IOException; import java.lang.reflect.Type; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import lombok.SneakyThrows; import org.junit.AfterClass; import org.junit.Assert; @@ -136,6 +140,31 @@ public void updateDataSourceAPITest() { Assert.assertEquals( "https://randomtest.com:9090", dataSourceMetadata.getProperties().get("prometheus.uri")); Assert.assertEquals("", dataSourceMetadata.getDescription()); + + // patch datasource + Map<String, Object> updateDS = + new HashMap<>(Map.of(NAME_FIELD, "update_prometheus", DESCRIPTION_FIELD, "test")); + Request patchRequest = getPatchDataSourceRequest(updateDS); + Response patchResponse = client().performRequest(patchRequest); + Assert.assertEquals(200, patchResponse.getStatusLine().getStatusCode()); + String patchResponseString = getResponseBody(patchResponse); + Assert.assertEquals("\"Updated DataSource with name update_prometheus\"", patchResponseString); + + // Datasource is not immediately updated, so introduce a sleep of 2s. + Thread.sleep(2000); + + // get datasource to validate the modification + Request getRequestAfterPatch = getFetchDataSourceRequest("update_prometheus"); + Response getResponseAfterPatch = client().performRequest(getRequestAfterPatch); + Assert.assertEquals(200, getResponseAfterPatch.getStatusLine().getStatusCode()); + String getResponseStringAfterPatch = getResponseBody(getResponseAfterPatch); + DataSourceMetadata dataSourceMetadataAfterPatch = + new Gson().fromJson(getResponseStringAfterPatch, DataSourceMetadata.class); + Assert.assertEquals( + "https://randomtest.com:9090", + dataSourceMetadataAfterPatch.getProperties().get("prometheus.uri")); + Assert.assertEquals("test", dataSourceMetadataAfterPatch.getDescription()); } @SneakyThrows diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java index 8335ada5a7..058182f123 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java @@ -49,6 +49,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Locale; +import java.util.Map; import javax.management.MBeanServerInvocationHandler; import javax.management.ObjectName; import javax.management.remote.JMXConnector; @@ -488,6 +489,15 @@ protected static Request getUpdateDataSourceRequest(DataSourceMetadata dataSourc return request; } + protected static Request getPatchDataSourceRequest(Map<String, Object> dataSourceData) { + Request request = new Request("PATCH", "/_plugins/_query/_datasources"); + request.setJsonEntity(new Gson().toJson(dataSourceData)); + RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); + restOptionsBuilder.addHeader("Content-Type", "application/json"); + request.setOptions(restOptionsBuilder); + return request; + } + protected static Request getFetchDataSourceRequest(String name) { Request request = new Request("GET", "/_plugins/_query/_datasources" + "/" + name); if (StringUtils.isEmpty(name)) { diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index a9a35f6318..eb6eabf988 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -59,18 +59,12 @@ import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.datasources.encryptor.EncryptorImpl; import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; -import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse; -import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionResponse; -import org.opensearch.sql.datasources.model.transport.GetDataSourceActionResponse; -import org.opensearch.sql.datasources.model.transport.UpdateDataSourceActionResponse; +import org.opensearch.sql.datasources.model.transport.*; import org.opensearch.sql.datasources.rest.RestDataSourceQueryAction; import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; import org.opensearch.sql.datasources.service.DataSourceServiceImpl; import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; -import org.opensearch.sql.datasources.transport.TransportCreateDataSourceAction; -import org.opensearch.sql.datasources.transport.TransportDeleteDataSourceAction; -import org.opensearch.sql.datasources.transport.TransportGetDataSourceAction; -import 
org.opensearch.sql.datasources.transport.TransportUpdateDataSourceAction; +import org.opensearch.sql.datasources.transport.*; import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.executor.AsyncRestExecutor; import org.opensearch.sql.legacy.metrics.Metrics; @@ -183,6 +177,10 @@ public List getRestHandlers( new ActionType<>( TransportUpdateDataSourceAction.NAME, UpdateDataSourceActionResponse::new), TransportUpdateDataSourceAction.class), + new ActionHandler<>( + new ActionType<>( + TransportPatchDataSourceAction.NAME, PatchDataSourceActionResponse::new), + TransportPatchDataSourceAction.class), new ActionHandler<>( new ActionType<>( TransportDeleteDataSourceAction.NAME, DeleteDataSourceActionResponse::new), From 7b4156e0ad3b9194cc0bf59f43971e67c3941aae Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 20 Oct 2023 10:01:35 -0700 Subject: [PATCH 09/16] Integration with REPL Spark job (#2327) * add InteractiveSession and SessionManager Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * add statement Signed-off-by: Peng Huo * fix format Signed-off-by: Peng Huo * snapshot Signed-off-by: Peng Huo * address comments Signed-off-by: Peng Huo * update Signed-off-by: Peng Huo * Update REST and Transport interface Signed-off-by: Peng Huo * Revert on transport layer Signed-off-by: Peng Huo * format code Signed-off-by: Peng Huo * add API doc Signed-off-by: Peng Huo * modify api Signed-off-by: Peng Huo * create query_execution_request index on demand Signed-off-by: Peng Huo * add REPL spark parameters Signed-off-by: Peng Huo * Add IT Signed-off-by: Peng Huo * format code Signed-off-by: Peng Huo * bind request index to datasource Signed-off-by: Peng Huo * fix bug when fetch query result Signed-off-by: Peng Huo * revert entrypoint class Signed-off-by: Peng Huo * update mapping Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- .../org/opensearch/sql/plugin/SQLPlugin.java | 5 +- spark/build.gradle | 1 + .../model/SparkSubmitParameters.java | 14 +- .../spark/data/constants/SparkConstants.java | 4 + .../dispatcher/SparkQueryDispatcher.java | 67 ++-- .../session/CreateSessionRequest.java | 21 +- .../execution/session/InteractiveSession.java | 15 +- .../spark/execution/session/SessionId.java | 21 +- .../execution/session/SessionManager.java | 5 +- .../spark/execution/session/SessionState.java | 7 +- .../spark/execution/session/SessionType.java | 14 +- .../spark/execution/statement/Statement.java | 20 +- .../execution/statement/StatementModel.java | 10 + .../execution/statement/StatementState.java | 7 +- .../execution/statestore/StateStore.java | 203 +++++++--- .../response/JobExecutionResponseReader.java | 4 + .../query_execution_request_mapping.yml | 40 ++ .../query_execution_request_settings.yml | 11 + ...AsyncQueryExecutorServiceImplSpecTest.java | 374 ++++++++++++++++++ .../dispatcher/SparkQueryDispatcherTest.java | 6 +- .../session/InteractiveSessionTest.java | 55 ++- .../execution/statement/StatementTest.java | 63 +-- 22 files changed, 810 insertions(+), 157 deletions(-) create mode 100644 spark/src/main/resources/query_execution_request_mapping.yml create mode 100644 spark/src/main/resources/query_execution_request_settings.yml create mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index eb6eabf988..f714a8366b 100644 --- 
a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.emrserverless.AWSEMRServerless; @@ -321,9 +320,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( new FlintIndexMetadataReaderImpl(client), client, new SessionManager( - new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client), - emrServerlessClient, - pluginSettings)); + new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/build.gradle b/spark/build.gradle index 15f1e200e0..8f4388495e 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -68,6 +68,7 @@ dependencies { because 'allows tests to run from IDEs that bundle older version of launcher' } testImplementation("org.opensearch.test:framework:${opensearch_version}") + testImplementation project(':opensearch') } test { diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 0609d8903c..db78abb2a8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -12,6 +12,7 @@ import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; import static org.opensearch.sql.spark.data.constants.SparkConstants.*; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.net.URI; import java.net.URISyntaxException; @@ -39,7 +40,7 @@ public class SparkSubmitParameters { public static class Builder { - private final String className; + private String className; private final Map config; private String extraParameters; @@ -70,6 +71,11 @@ public static Builder builder() { return new Builder(); } + public Builder className(String className) { + this.className = className; + return this; + } + public Builder dataSource(DataSourceMetadata metadata) { if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); @@ -141,6 +147,12 @@ public Builder extraParameters(String params) { return this; } + public Builder sessionExecution(String sessionId, String datasourceName) { + config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + config.put(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + public SparkSubmitParameters build() { return new SparkSubmitParameters(className, config, extraParameters); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 1b248eb15d..85ce3c4989 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ 
b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -87,4 +87,8 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; + + public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; + public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; + public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 8d5ae10e91..2bd1ae67b9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; @@ -96,12 +97,19 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result(); } - // either empty json when the result is not available or data with status - // Fetch from Result Index - JSONObject result = - jobExecutionResponseReader.getResultFromOpensearchIndex( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); - + JSONObject result; + if (asyncQueryJobMetadata.getSessionId() == null) { + // either empty json when the result is not available or data with status + // Fetch from Result Index + result = + jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + } else { + // when session enabled, jobId in asyncQueryJobMetadata is actually queryId. + result = + jobExecutionResponseReader.getResultWithQueryId( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + } // if result index document has a status, we are gonna use the status directly; otherwise, we // will use emr-s job status. // That a job is successful does not mean there is no error in execution. 
For example, even if @@ -230,22 +238,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - false, - dataSourceMetadata.getResultIndex()); + if (sessionManager.isEnabled()) { Session session; if (dispatchQueryRequest.getSessionId() != null) { @@ -260,7 +253,19 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ // create session if not exist session = sessionManager.createSession( - new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName())); + new CreateSessionRequest( + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .className(FLINT_SESSION_CLASS_NAME) + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), + tags, + dataSourceMetadata.getResultIndex(), + dataSourceMetadata.getName())); } StatementId statementId = session.submit( @@ -272,6 +277,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); } else { + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + false, + dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 17e3346248..ca2b2b4867 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -5,11 +5,30 @@ package org.opensearch.sql.spark.execution.session; +import java.util.Map; import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.StartJobRequest; @Data public class CreateSessionRequest { - private final StartJobRequest startJobRequest; + private final String jobName; + private final String applicationId; + private final String executionRoleArn; + private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder; + private final Map tags; + private final String resultIndex; private final String 
datasourceName; + + public StartJobRequest getStartJobRequest() { + return new StartJobRequest( + "select 1", + jobName, + applicationId, + executionRoleArn, + sparkSubmitParametersBuilder.build().toString(), + tags, + false, + resultIndex); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index e33ef4245a..4428c3b83d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -42,13 +42,17 @@ public class InteractiveSession implements Session { @Override public void open(CreateSessionRequest createSessionRequest) { try { + // append session id; + createSessionRequest + .getSparkSubmitParametersBuilder() + .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - createSession(stateStore).apply(sessionModel); + createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -59,7 +63,8 @@ public void open(CreateSessionRequest createSessionRequest) { /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional model = getSession(stateStore).apply(sessionModel.getId()); + Optional model = + getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -69,7 +74,8 @@ public void close() { /** Submit statement. If submit successfully, Statement in waiting state. */ public StatementId submit(QueryRequest request) { - Optional model = getSession(stateStore).apply(sessionModel.getId()); + Optional model = + getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. 
" + sessionModel.getSessionId()); } else { @@ -84,6 +90,7 @@ public StatementId submit(QueryRequest request) { .stateStore(stateStore) .statementId(statementId) .langType(LangType.SQL) + .datasourceName(sessionModel.getDatasourceName()) .query(request.getQuery()) .queryId(statementId.getId()) .build(); @@ -103,7 +110,7 @@ public StatementId submit(QueryRequest request) { @Override public Optional get(StatementId stID) { - return StateStore.getStatement(stateStore) + return StateStore.getStatement(stateStore, sessionModel.getDatasourceName()) .apply(stID.getId()) .map( model -> diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index 861d906b9b..b3bd716925 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -5,15 +5,32 @@ package org.opensearch.sql.spark.execution.session; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import lombok.Data; import org.apache.commons.lang3.RandomStringUtils; @Data public class SessionId { + public static final int PREFIX_LEN = 10; + private final String sessionId; - public static SessionId newSessionId() { - return new SessionId(RandomStringUtils.randomAlphanumeric(16)); + public static SessionId newSessionId(String datasourceName) { + return new SessionId(encode(datasourceName)); + } + + public String getDataSourceName() { + return decode(sessionId); + } + + private static String decode(String sessionId) { + return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN); + } + + private static String encode(String datasourceName) { + String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; + return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index c34be7015f..c0f7bbcde8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -28,7 +28,7 @@ public class SessionManager { public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() - .sessionId(newSessionId()) + .sessionId(newSessionId(request.getDatasourceName())) .stateStore(stateStore) .serverlessClient(emrServerlessClient) .build(); @@ -37,7 +37,8 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); + Optional model = + StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index a4da957f12..bd5d14c603 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; 
import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -32,8 +33,10 @@ public enum SessionState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static SessionState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); + for (SessionState ss : SessionState.values()) { + if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) { + return ss; + } } throw new IllegalArgumentException("Invalid session state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java index dd179a1dc5..10b9ce7bd5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -5,9 +5,7 @@ package org.opensearch.sql.spark.execution.session; -import java.util.Arrays; -import java.util.Map; -import java.util.stream.Collectors; +import java.util.Locale; import lombok.Getter; @Getter @@ -20,13 +18,11 @@ public enum SessionType { this.sessionType = sessionType; } - private static Map TYPES = - Arrays.stream(SessionType.values()) - .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); - public static SessionType fromString(String key) { - if (TYPES.containsKey(key)) { - return TYPES.get(key); + for (SessionType sType : SessionType.values()) { + if (sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) { + return sType; + } } throw new IllegalArgumentException("Invalid session type: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index 8fcedb5fca..d84c91bdb8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -32,6 +32,7 @@ public class Statement { private final String jobId; private final StatementId statementId; private final LangType langType; + private final String datasourceName; private final String query; private final String queryId; private final StateStore stateStore; @@ -42,8 +43,16 @@ public class Statement { public void open() { try { statementModel = - submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); - statementModel = createStatement(stateStore).apply(statementModel); + submitStatement( + sessionId, + applicationId, + jobId, + statementId, + langType, + datasourceName, + query, + queryId); + statementModel = createStatement(stateStore, datasourceName).apply(statementModel); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; LOG.error(errorMsg); @@ -61,7 +70,8 @@ public void cancel() { } try { this.statementModel = - updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); + updateStatementState(stateStore, statementModel.getDatasourceName()) + .apply(this.statementModel, StatementState.CANCELLED); } catch (DocumentMissingException e) { String errorMsg = String.format("cancel statement failed. no statement found. 
statement: %s.", statementId); @@ -69,7 +79,9 @@ public void cancel() { throw new IllegalStateException(errorMsg); } catch (VersionConflictEngineException e) { this.statementModel = - getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); + getStatement(stateStore, statementModel.getDatasourceName()) + .apply(statementModel.getId()) + .orElse(this.statementModel); String errorMsg = String.format( "cancel statement failed. current statementState: %s " + "statement: %s.", diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index c7f681c541..2a1043bf73 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.statement; import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; +import static org.opensearch.sql.spark.execution.session.SessionModel.DATASOURCE_NAME; import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; @@ -45,6 +46,7 @@ public class StatementModel extends StateModel { private final String applicationId; private final String jobId; private final LangType langType; + private final String datasourceName; private final String query; private final String queryId; private final long submitTime; @@ -65,6 +67,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(APPLICATION_ID, applicationId) .field(JOB_ID, jobId) .field(LANG, langType.getText()) + .field(DATASOURCE_NAME, datasourceName) .field(QUERY, query) .field(QUERY_ID, queryId) .field(SUBMIT_TIME, submitTime) @@ -82,6 +85,7 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT .applicationId(copy.applicationId) .jobId(copy.jobId) .langType(copy.langType) + .datasourceName(copy.datasourceName) .query(copy.query) .queryId(copy.queryId) .submitTime(copy.submitTime) @@ -101,6 +105,7 @@ public static StatementModel copyWithState( .applicationId(copy.applicationId) .jobId(copy.jobId) .langType(copy.langType) + .datasourceName(copy.datasourceName) .query(copy.query) .queryId(copy.queryId) .submitTime(copy.submitTime) @@ -143,6 +148,9 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon case LANG: builder.langType(LangType.fromString(parser.text())); break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; case QUERY: builder.query(parser.text()); break; @@ -168,6 +176,7 @@ public static StatementModel submitStatement( String jobId, StatementId statementId, LangType langType, + String datasourceName, String query, String queryId) { return builder() @@ -178,6 +187,7 @@ public static StatementModel submitStatement( .applicationId(applicationId) .jobId(jobId) .langType(langType) + .datasourceName(datasourceName) .query(query) .queryId(queryId) .submitTime(System.currentTimeMillis()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java index 33f7f5e831..48978ff8f9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ 
b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.statement; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -30,8 +31,10 @@ public enum StatementState { .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); public static StatementState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); + for (StatementState ss : StatementState.values()) { + if (ss.getState().toLowerCase(Locale.ROOT).equals(key)) { + return ss; + } } throw new IllegalArgumentException("Invalid statement state: " + key); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index bd72b17353..a36ee3ef45 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -5,15 +5,22 @@ package org.opensearch.sql.spark.execution.statestore; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; + import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.Locale; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; import lombok.RequiredArgsConstructor; +import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; @@ -22,6 +29,9 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -33,15 +43,29 @@ import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; +/** + * StateStore maintains the state of Session and Statement. StateStore creates/updates/gets docs + * on the index regardless of the user's FGAC permissions. 
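+ * The per-datasource request index is created lazily on first use, and each read/write runs + * inside a stashed thread context, so the call executes as the plugin rather than as the end user. 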
+ */ @RequiredArgsConstructor public class StateStore { + public static String SETTINGS_FILE_NAME = "query_execution_request_settings.yml"; + public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml"; + public static Function DATASOURCE_TO_REQUEST_INDEX = + datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); + public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME); + private static final Logger LOG = LogManager.getLogger(); - private final String indexName; private final Client client; + private final ClusterService clusterService; - protected T create(T st, StateModel.CopyBuilder builder) { + protected T create( + T st, StateModel.CopyBuilder builder, String indexName) { try { + if (!this.clusterService.state().routingTable().hasIndex(indexName)) { + createIndex(indexName); + } IndexRequest indexRequest = new IndexRequest(indexName) .id(st.getId()) @@ -50,48 +74,60 @@ protected T create(T st, StateModel.CopyBuilde .setIfPrimaryTerm(st.getPrimaryTerm()) .create(true) .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", st.getId()); - return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. id: %s, error: %s", - st.getId(), - indexResponse.getResult().getLowercase())); + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", st.getId()); + return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. 
id: %s, error: %s", + st.getId(), + indexResponse.getResult().getLowercase())); + } } } catch (IOException e) { throw new RuntimeException(e); } } - protected Optional get(String sid, StateModel.FromXContent builder) { + protected Optional get( + String sid, StateModel.FromXContent builder, String indexName) { try { - GetRequest getRequest = new GetRequest().index(indexName).id(sid); - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { + if (!this.clusterService.state().routingTable().hasIndex(indexName)) { + createIndex(indexName); return Optional.empty(); } + GetRequest getRequest = new GetRequest().index(indexName).id(sid).refresh(true); + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } } catch (IOException e) { throw new RuntimeException(e); } } protected T updateState( - T st, S state, StateModel.StateCopyBuilder builder) { + T st, S state, StateModel.StateCopyBuilder builder, String indexName) { try { T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); UpdateRequest updateRequest = @@ -103,47 +139,110 @@ protected T updateState( .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) .fetchSource(true) .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - UpdateResponse updateResponse = client.update(updateRequest).actionGet(); - if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { - LOG.debug("Successfully update doc. id: {}", st.getId()); - return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed update doc. id: %s, error: %s", - st.getId(), - updateResponse.getResult().getLowercase())); + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { + LOG.debug("Successfully update doc. id: {}", st.getId()); + return builder.of( + model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed update doc. 
id: %s, error: %s", + st.getId(), + updateResponse.getResult().getLowercase())); + } } } catch (IOException e) { throw new RuntimeException(e); } } + private void createIndex(String indexName) { + try { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + createIndexRequest + .mapping(loadConfigFromResource(MAPPING_FILE_NAME), XContentType.YAML) + .settings(loadConfigFromResource(SETTINGS_FILE_NAME), XContentType.YAML); + ActionFuture createIndexResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); + } + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + if (createIndexResponse.isAcknowledged()) { + LOG.info("Index: {} creation Acknowledged", indexName); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + throw new RuntimeException( + "Internal server error while creating" + indexName + " index:: " + e.getMessage()); + } + } + + private String loadConfigFromResource(String fileName) throws IOException { + InputStream fileStream = StateStore.class.getClassLoader().getResourceAsStream(fileName); + return IOUtils.toString(fileStream, StandardCharsets.UTF_8); + } + /** Helper Functions */ - public static Function createStatement(StateStore stateStore) { - return (st) -> stateStore.create(st, StatementModel::copy); + public static Function createStatement( + StateStore stateStore, String datasourceName) { + return (st) -> + stateStore.create( + st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function> getStatement(StateStore stateStore) { - return (docId) -> stateStore.get(docId, StatementModel::fromXContent); + public static Function> getStatement( + StateStore stateStore, String datasourceName) { + return (docId) -> + stateStore.get( + docId, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } public static BiFunction updateStatementState( - StateStore stateStore) { - return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState); + StateStore stateStore, String datasourceName) { + return (old, state) -> + stateStore.updateState( + old, + state, + StatementModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + public static Function createSession( + StateStore stateStore, String datasourceName) { + return (session) -> + stateStore.create( + session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function createSession(StateStore stateStore) { - return (session) -> stateStore.create(session, SessionModel::of); + public static Function> getSession( + StateStore stateStore, String datasourceName) { + return (docId) -> + stateStore.get( + docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function> getSession(StateStore stateStore) { - return (docId) -> stateStore.get(docId, SessionModel::fromXContent); + public static Function> searchSession(StateStore stateStore) { + return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX); } public static BiFunction updateSessionState( - StateStore stateStore) { - return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState); + StateStore stateStore, String datasourceName) { + return (old, state) -> + 
stateStore.updateState( + old, + state, + SessionModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) { + String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); + return () -> stateStore.createIndex(indexName); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java index d3cbd68dce..2614992463 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -39,6 +39,10 @@ public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex); } + public JSONObject getResultWithQueryId(String queryId, String resultIndex) { + return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultIndex); + } + private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) { SearchRequest searchRequest = new SearchRequest(); String searchResultIndex = resultIndex == null ? SPARK_RESPONSE_BUFFER_INDEX_NAME : resultIndex; diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml new file mode 100644 index 0000000000..87bd927e6e --- /dev/null +++ b/spark/src/main/resources/query_execution_request_mapping.yml @@ -0,0 +1,40 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the query_execution_request index +# Also "dynamic" is set to "false" so that other fields can be added. 
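+# (With dynamic set to "false", fields not declared below are still kept in _source but are not indexed or searchable.) 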
+dynamic: false +properties: + type: + type: keyword + state: + type: keyword + statementId: + type: keyword + applicationId: + type: keyword + sessionId: + type: keyword + sessionType: + type: keyword + error: + type: text + lang: + type: keyword + query: + type: text + dataSourceName: + type: keyword + submitTime: + type: date + format: strict_date_time||epoch_millis + jobId: + type: keyword + lastUpdateTime: + type: date + format: strict_date_time||epoch_millis + queryId: + type: keyword diff --git a/spark/src/main/resources/query_execution_request_settings.yml b/spark/src/main/resources/query_execution_request_settings.yml new file mode 100644 index 0000000000..da2bf07bf1 --- /dev/null +++ b/spark/src/main/resources/query_execution_request_settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .ql-job-metadata index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java new file mode 100644 index 0000000000..3eb8958eb2 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -0,0 +1,374 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery; + +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_ENABLED_SETTING; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE; +import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; +import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import lombok.Getter; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import 
org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.datasources.encryptor.EncryptorImpl; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.test.OpenSearchIntegTestCase; + +public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase { + public static final String DATASOURCE = "mys3"; + + private ClusterService clusterService; + private org.opensearch.sql.common.setting.Settings pluginSettings; + private NodeClient client; + private DataSourceServiceImpl dataSourceService; + private StateStore stateStore; + private ClusterSettings clusterSettings; + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(TestSettingPlugin.class); + } + + public static class TestSettingPlugin extends Plugin { + @Override + public List> getSettings() { + return OpenSearchSettings.pluginSettings(); + } + } + + @Before + public void setup() { + clusterService = clusterService(); + clusterSettings = clusterService.getClusterSettings(); + pluginSettings = new OpenSearchSettings(clusterSettings); + client = (NodeClient) cluster().client(); + dataSourceService = createDataSourceService(); + dataSourceService.createDataSource( + new DataSourceMetadata( + DATASOURCE, + DataSourceType.S3GLUE, + ImmutableList.of(), + ImmutableMap.of( + "glue.auth.type", + "iam_role", + "glue.auth.role_arn", + "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", + "glue.indexstore.opensearch.uri", + "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200", + "glue.indexstore.opensearch.auth", + "noauth"), + null)); + stateStore = new StateStore(client, clusterService); + createIndex(SPARK_RESPONSE_BUFFER_INDEX_NAME); + } + + @After + public void clean() { + client + .admin() + .cluster() + .prepareUpdateSettings() + 
.setTransientSettings( + Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()) + .get(); + } + + @Test + public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // disable session + enableSession(false); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); + emrsClient.startJobRunCalled(1); + + // 2. fetch async query result. + AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals("RUNNING", asyncQueryResults.getStatus()); + emrsClient.getJobRunResultCalled(1); + + // 3. cancel async query. + String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + assertEquals(response.getQueryId(), cancelQueryId); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void createAsyncQueryCreateJobWithCorrectParameters() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + enableSession(false); + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + String params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertNull(response.getSessionId()); + assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); + assertFalse( + params.contains( + String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertFalse( + params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); + + // enable session + enableSession(true); + response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); + assertTrue( + params.contains( + String.format("%s=%s", FLINT_JOB_REQUEST_INDEX, SPARK_REQUEST_BUFFER_INDEX_NAME))); + assertTrue( + params.contains(String.format("%s=%s", FLINT_JOB_SESSION_ID, response.getSessionId()))); + } + + @Test + public void withSessionCreateAsyncQueryThenGetResultThenCancel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(response.getSessionId()); + Optional statementModel = + getStatement(stateStore, DATASOURCE).apply(response.getQueryId()); + assertTrue(statementModel.isPresent()); + assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); + + // 2. fetch async query result. 
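+ // With sessions disabled, the status is resolved from the EMR-S job state; the stubbed LocalEMRSClient below always reports RUNNING. 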
+ AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); + + // 3. cancel async query. + String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + assertEquals(response.getQueryId(), cancelQueryId); + } + + @Test + public void reuseSessionWhenCreateAsyncQuery() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // 2. reuse session id + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, first.getSessionId())); + + assertEquals(first.getSessionId(), second.getSessionId()); + assertNotEquals(first.getQueryId(), second.getQueryId()); + // one session doc. + assertEquals( + 1, + search( + QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", SESSION_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + // two statement docs has same sessionId. + assertEquals( + 2, + search( + QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("type", STATEMENT_DOC_TYPE)) + .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); + + Optional firstModel = + getStatement(stateStore, DATASOURCE).apply(first.getQueryId()); + assertTrue(firstModel.isPresent()); + assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); + assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); + assertEquals(first.getQueryId(), firstModel.get().getQueryId()); + Optional secondModel = + getStatement(stateStore, DATASOURCE).apply(second.getQueryId()); + assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); + assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); + assertEquals(second.getQueryId(), secondModel.get().getQueryId()); + } + + private DataSourceServiceImpl createDataSourceService() { + String masterKey = "1234567890"; + DataSourceMetadataStorage dataSourceMetadataStorage = + new OpenSearchDataSourceMetadataStorage( + client, clusterService, new EncryptorImpl(masterKey)); + return new DataSourceServiceImpl( + new ImmutableSet.Builder() + .add(new GlueDataSourceFactory(pluginSettings)) + .build(), + dataSourceMetadataStorage, + meta -> {}); + } + + private AsyncQueryExecutorService createAsyncQueryExecutorService( + EMRServerlessClient emrServerlessClient) { + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = + new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, + this.dataSourceService, + new DataSourceUserAuthorizationHelperImpl(client), + jobExecutionResponseReader, + new FlintIndexMetadataReaderImpl(client), + client, + new SessionManager( + new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); + return new AsyncQueryExecutorServiceImpl( + 
asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + this::sparkExecutionEngineConfig); + } + + public static class LocalEMRSClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + private int getJobResult = 0; + + @Getter private StartJobRequest jobRequest; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + jobRequest = startJobRequest; + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + getJobResult++; + JobRun jobRun = new JobRun(); + jobRun.setState("RUNNING"); + return new GetJobRunResult().withJobRun(jobRun); + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return new CancelJobRunResult().withJobRunId(jobId); + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + + public void getJobRunResultCalled(int expectedTimes) { + assertEquals(expectedTimes, getJobResult); + } + } + + public SparkExecutionEngineConfig sparkExecutionEngineConfig() { + return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); + } + + public void enableSession(boolean enabled) { + client + .admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings( + Settings.builder() + .put(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey(), enabled) + .build()) + .get(); + } + + int search(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(DATASOURCE)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + SearchResponse searchResponse = client.search(searchRequest).actionGet(); + + return searchResponse.getHits().getHits().length; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 58fe626dae..15211dec01 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -726,7 +726,7 @@ void testGetQueryResponseWithSession() { doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); JSONObject result = sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); @@ -740,7 +740,7 @@ void testGetQueryResponseWithInvalidSession() { doReturn(Optional.empty()).when(sessionManager).getSession(eq(new SessionId(MOCK_SESSION_ID))); doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); IllegalArgumentException exception = Assertions.assertThrows( IllegalArgumentException.class, @@ -759,7 +759,7 @@ void testGetQueryResponseWithStatementNotExist() { doReturn(Optional.empty()).when(session).get(any()); doReturn(new JSONObject()) .when(jobExecutionResponseReader) - .getResultFromOpensearchIndex(eq(MOCK_STATEMENT_ID), any()); + 
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); IllegalArgumentException exception = Assertions.assertThrows( diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 429c970365..06a8d8c73c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -8,10 +8,12 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -20,15 +22,17 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. 
*/ -public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { +public class InteractiveSessionTest extends OpenSearchIntegTestCase { - private static final String indexName = "mockindex"; + private static final String DS_NAME = "mys3"; + private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; @@ -38,20 +42,21 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(indexName, client()); - createIndex(indexName); + stateStore = new StateStore(client(), clusterService()); } @After public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + if (clusterService().state().routingTable().hasIndex(indexName)) { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } } @Test public void openCloseSession() { InteractiveSession session = InteractiveSession.builder() - .sessionId(SessionId.newSessionId()) + .sessionId(SessionId.newSessionId(DS_NAME)) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); @@ -59,7 +64,7 @@ public void openCloseSession() { // open session TestSession testSession = testSession(session, stateStore); testSession - .open(new CreateSessionRequest(startJobRequest, "datasource")) + .open(createSessionRequest()) .assertSessionState(NOT_STARTED) .assertAppId("appId") .assertJobId("jobId"); @@ -72,14 +77,14 @@ public void openCloseSession() { @Test public void openSessionFailedConflict() { - SessionId sessionId = new SessionId("duplicate-session-id"); + SessionId sessionId = SessionId.newSessionId(DS_NAME); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(new CreateSessionRequest(startJobRequest, "datasource")); + session.open(createSessionRequest()); InteractiveSession duplicateSession = InteractiveSession.builder() @@ -89,21 +94,20 @@ public void openSessionFailedConflict() { .build(); IllegalStateException exception = assertThrows( - IllegalStateException.class, - () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); - assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage()); + IllegalStateException.class, () -> duplicateSession.open(createSessionRequest())); + assertEquals("session already exist. 
" + sessionId, exception.getMessage()); } @Test public void closeNotExistSession() { - SessionId sessionId = SessionId.newSessionId(); + SessionId sessionId = SessionId.newSessionId(DS_NAME); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) .serverlessClient(emrsClient) .build(); - session.open(new CreateSessionRequest(startJobRequest, "datasource")); + session.open(createSessionRequest()); client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); @@ -116,7 +120,7 @@ public void closeNotExistSession() { public void sessionManagerCreateSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); TestSession testSession = testSession(session, stateStore); testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); @@ -126,8 +130,7 @@ public void sessionManagerCreateSession() { public void sessionManagerGetSession() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Session session = - sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); + Session session = sessionManager.createSession(createSessionRequest()); Optional managerSession = sessionManager.getSession(session.getSessionId()); assertTrue(managerSession.isPresent()); @@ -139,7 +142,8 @@ public void sessionManagerGetSessionNotExist() { SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting(false)); - Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); + Optional managerSession = + sessionManager.getSession(SessionId.newSessionId("no-exist")); assertTrue(managerSession.isEmpty()); } @@ -156,7 +160,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - getSession(stateStore).apply(session.getSessionModel().getId()); + getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -184,6 +188,17 @@ public TestSession close() { } } + public static CreateSessionRequest createSessionRequest() { + return new CreateSessionRequest( + "jobName", + "appId", + "arn", + SparkSubmitParameters.Builder.builder(), + ImmutableMap.of(), + "resultIndex", + DS_NAME); + } + public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 214bcb8258..ff3ddd1bef 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,15 +5,16 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static 
org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; -import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.junit.After; @@ -21,8 +22,6 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -30,27 +29,27 @@ import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.rest.model.LangType; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; -public class StatementTest extends OpenSearchSingleNodeTestCase { +public class StatementTest extends OpenSearchIntegTestCase { - private static final String indexName = "mockindex"; + private static final String DS_NAME = "mys3"; + private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); - private StartJobRequest startJobRequest; private StateStore stateStore; private InteractiveSessionTest.TestEMRServerlessClient emrsClient = new InteractiveSessionTest.TestEMRServerlessClient(); @Before public void setup() { - startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(indexName, client()); - createIndex(indexName); + stateStore = new StateStore(client(), clusterService()); } @After public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + if (clusterService().state().routingTable().hasIndex(indexName)) { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } } @Test @@ -62,6 +61,7 @@ public void openThenCancelStatement() { .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) + .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -87,6 +87,7 @@ public void openFailedBecauseConflict() { .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) + .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -101,6 +102,7 @@ public void openFailedBecauseConflict() { .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) + .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -119,13 +121,14 @@ public void cancelNotExistStatement() { .jobId("jobId") .statementId(stId) .langType(LangType.SQL) + .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) .build(); st.open(); - client().delete(new DeleteRequest(indexName, stId.getId())); + client().delete(new DeleteRequest(indexName, stId.getId())).actionGet(); IllegalStateException exception = 
assertThrows(IllegalStateException.class, st::cancel); assertEquals( @@ -143,6 +146,7 @@ public void cancelFailedBecauseOfConflict() { .jobId("jobId") .statementId(stId) .langType(LangType.SQL) + .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -150,7 +154,7 @@ public void cancelFailedBecauseOfConflict() { st.open(); StatementModel running = - updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED); + updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED); assertEquals(StatementState.CANCELLED, running.getStatementState()); @@ -172,6 +176,7 @@ public void cancelRunningStatementFailed() { .jobId("jobId") .statementId(stId) .langType(LangType.SQL) + .datasourceName(DS_NAME) .query("query") .queryId("statementId") .stateStore(stateStore) @@ -198,10 +203,10 @@ public void cancelRunningStatementFailed() { public void submitStatementInRunningSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); assertFalse(statementId.getId().isEmpty()); @@ -211,7 +216,7 @@ public void submitStatementInRunningSession() { public void submitStatementInNotStartedState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); assertFalse(statementId.getId().isEmpty()); @@ -221,9 +226,9 @@ public void submitStatementInNotStartedState() { public void failToSubmitStatementInDeadState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); IllegalStateException exception = assertThrows( @@ -239,9 +244,9 @@ public void failToSubmitStatementInDeadState() { public void failToSubmitStatementInFailState() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); IllegalStateException exception = assertThrows( @@ -257,7 +262,7 @@ public void failToSubmitStatementInFailState() { public void newStatementFieldAssert() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -275,7 +280,7 @@ public void newStatementFieldAssert() { public 
void failToSubmitStatementInDeletedSession() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // other's delete session client() @@ -293,9 +298,9 @@ public void failToSubmitStatementInDeletedSession() { public void getStatementSuccess() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); Optional statement = session.get(statementId); @@ -308,9 +313,9 @@ public void getStatementSuccess() { public void getStatementNotExist() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + .createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); Optional statement = session.get(StatementId.newStatementId()); assertFalse(statement.isPresent()); @@ -328,7 +333,8 @@ public static TestStatement testStatement(Statement st, StateStore stateStore) { public TestStatement assertSessionState(StatementState expected) { assertEquals(expected, st.getStatementModel().getStatementState()); - Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + Optional model = + getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementState()); @@ -338,7 +344,8 @@ public TestStatement assertSessionState(StatementState expected) { public TestStatement assertStatementId(StatementId expected) { assertEquals(expected, st.getStatementModel().getStatementId()); - Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + Optional model = + getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementId()); return this; From b30d3c98691887952299ba0361e266e6e60fd48e Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 23 Oct 2023 14:28:03 -0700 Subject: [PATCH 10/16] deprecated job-metadata-index (#2339) * deprecate job-metadata-index Signed-off-by: Peng Huo * upgrade log4j2 Signed-off-by: Peng Huo * update codestyle Signed-off-by: Peng Huo * upgrade log4j Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- common/build.gradle | 2 +- integ-test/build.gradle | 2 +- .../org/opensearch/sql/plugin/SQLPlugin.java | 6 +- ppl/build.gradle | 2 +- .../AsyncQueryExecutorServiceImpl.java | 3 +- ...chAsyncQueryJobMetadataStorageService.java | 161 +---------- .../spark/asyncquery/model/AsyncQueryId.java | 35 +++ .../model/AsyncQueryJobMetadata.java | 157 +++++++--- .../spark/dispatcher/AsyncQueryHandler.java | 49 ++++ .../spark/dispatcher/BatchQueryHandler.java | 50 ++++ .../dispatcher/InteractiveQueryHandler.java | 69 +++++ .../dispatcher/SparkQueryDispatcher.java | 126 ++------ .../model/DispatchQueryResponse.java | 2 + 
.../execution/session/InteractiveSession.java | 5 +- .../spark/execution/session/SessionId.java | 15 +- .../execution/statement/QueryRequest.java | 2 + .../execution/statement/StatementId.java | 6 +- .../execution/statestore/StateStore.java | 26 +- .../opensearch/sql/spark/utils/IDUtils.java | 25 ++ .../resources/job-metadata-index-mapping.yml | 25 -- .../resources/job-metadata-index-settings.yml | 11 - .../query_execution_request_mapping.yml | 2 + ...AsyncQueryExecutorServiceImplSpecTest.java | 6 +- .../AsyncQueryExecutorServiceImplTest.java | 30 +- ...yncQueryJobMetadataStorageServiceTest.java | 272 ++++-------------- .../dispatcher/SparkQueryDispatcherTest.java | 54 ++-- .../execution/statement/StatementTest.java | 27 +- 27 files changed, 538 insertions(+), 632 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java delete mode 100644 spark/src/main/resources/job-metadata-index-mapping.yml delete mode 100644 spark/src/main/resources/job-metadata-index-settings.yml diff --git a/common/build.gradle b/common/build.gradle index 109cad59cb..5386c32468 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -35,7 +35,7 @@ repositories { dependencies { api "org.antlr:antlr4-runtime:4.7.1" api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' - api group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0' + api group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.21.0' api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' implementation 'com.github.babbel:okhttp-aws-signer:1.0.2' diff --git a/integ-test/build.gradle b/integ-test/build.gradle index f2e70d9908..08b8cfb210 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -175,7 +175,7 @@ dependencies { testImplementation group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}" testImplementation group: 'org.opensearch.driver', name: 'opensearch-sql-jdbc', version: System.getProperty("jdbcDriverVersion", '1.2.0.0') testImplementation group: 'org.hamcrest', name: 'hamcrest', version: '2.1' - implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0' + implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.21.0' testImplementation project(':opensearch-sql-plugin') testImplementation project(':legacy') testImplementation('org.junit.jupiter:junit-jupiter-api:5.6.2') diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f714a8366b..3d9740d84c 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -306,8 +306,9 @@ private DataSourceServiceImpl createDataSourceService() { private AsyncQueryExecutorService createAsyncQueryExecutorService( SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, SparkExecutionEngineConfig sparkExecutionEngineConfig) { + StateStore stateStore = new 
StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); + new OpensearchAsyncQueryJobMetadataStorageService(stateStore); EMRServerlessClient emrServerlessClient = createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); @@ -319,8 +320,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager( - new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); + new SessionManager(stateStore, emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/ppl/build.gradle b/ppl/build.gradle index 04ad71ced6..d0d9fe3cbf 100644 --- a/ppl/build.gradle +++ b/ppl/build.gradle @@ -50,7 +50,7 @@ dependencies { implementation "org.antlr:antlr4-runtime:4.7.1" implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' api group: 'org.json', name: 'json', version: '20231013' - implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.20.0' + implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.21.0' api project(':common') api project(':core') api project(':protocol') diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 7cba2757cc..18ae47c2b9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -69,13 +69,14 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getSessionId())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( + dispatchQueryResponse.getQueryId(), sparkExecutionEngineConfig.getApplicationId(), dispatchQueryResponse.getJobId(), dispatchQueryResponse.isDropIndexQuery(), dispatchQueryResponse.getResultIndex(), dispatchQueryResponse.getSessionId())); return new CreateAsyncQueryResponse( - dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); + dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId()); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index a95a6ffe45..6de8c35f03 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -7,166 +7,31 @@ package org.opensearch.sql.spark.asyncquery; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createJobMetaData; + import java.util.Optional; -import org.apache.commons.io.IOUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import 
org.opensearch.action.DocWriteRequest; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.action.ActionFuture; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.execution.statestore.StateStore; /** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */ +@RequiredArgsConstructor public class OpensearchAsyncQueryJobMetadataStorageService implements AsyncQueryJobMetadataStorageService { - public static final String JOB_METADATA_INDEX = ".ql-job-metadata"; - private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME = - "job-metadata-index-mapping.yml"; - private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME = - "job-metadata-index-settings.yml"; - private static final Logger LOG = LogManager.getLogger(); - private final Client client; - private final ClusterService clusterService; - - /** - * This class implements JobMetadataStorageService interface using OpenSearch as underlying - * storage. - * - * @param client opensearch NodeClient. - * @param clusterService ClusterService. 
- */ - public OpensearchAsyncQueryJobMetadataStorageService( - Client client, ClusterService clusterService) { - this.client = client; - this.clusterService = clusterService; - } + private final StateStore stateStore; @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { - if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { - createJobMetadataIndex(); - } - IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX); - indexRequest.id(asyncQueryJobMetadata.getJobId()); - indexRequest.opType(DocWriteRequest.OpType.CREATE); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - ActionFuture<IndexResponse> indexResponseActionFuture; - IndexResponse indexResponse; - try (ThreadContext.StoredContext storedContext = - client.threadPool().getThreadContext().stashContext()) { - indexRequest.source(AsyncQueryJobMetadata.convertToXContent(asyncQueryJobMetadata)); - indexResponseActionFuture = client.index(indexRequest); - indexResponse = indexResponseActionFuture.actionGet(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("JobMetadata : {} successfully created", asyncQueryJobMetadata.getJobId()); - } else { - throw new RuntimeException( - "Saving job metadata information failed with result : " - + indexResponse.getResult().getLowercase()); - } + AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); + createJobMetaData(stateStore, queryId.getDataSourceName()).apply(asyncQueryJobMetadata); } @Override - public Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId) { - if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { - createJobMetadataIndex(); - return Optional.empty(); - } - return searchInJobMetadataIndex(QueryBuilders.termQuery("jobId.keyword", jobId)).stream() - .findFirst(); - } - - private void createJobMetadataIndex() { - try { - InputStream mappingFileStream = - OpensearchAsyncQueryJobMetadataStorageService.class - .getClassLoader() - .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME); - InputStream settingsFileStream = - OpensearchAsyncQueryJobMetadataStorageService.class - .getClassLoader() - .getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME); - CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX); - createIndexRequest - .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML) - .settings( - IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); - ActionFuture<CreateIndexResponse> createIndexResponseActionFuture; - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); - } - CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); - if (createIndexResponse.isAcknowledged()) { - LOG.info("Index: {} creation Acknowledged", JOB_METADATA_INDEX); - } else { - throw new RuntimeException("Index creation is not acknowledged."); - } - } catch (Throwable e) { - throw new RuntimeException( - "Internal server error while creating" - + JOB_METADATA_INDEX - + " index:: " - + e.getMessage()); - } - } - - private List<AsyncQueryJobMetadata> searchInJobMetadataIndex(QueryBuilder query) { - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(JOB_METADATA_INDEX); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(query); - 
searchSourceBuilder.size(1); - searchRequest.source(searchSourceBuilder); - // https://github.com/opensearch-project/sql/issues/1801. - searchRequest.preference("_primary_first"); - ActionFuture<SearchResponse> searchResponseActionFuture; - try (ThreadContext.StoredContext ignored = - client.threadPool().getThreadContext().stashContext()) { - searchResponseActionFuture = client.search(searchRequest); - } - SearchResponse searchResponse = searchResponseActionFuture.actionGet(); - if (searchResponse.status().getStatus() != 200) { - throw new RuntimeException( - "Fetching job metadata information failed with status : " + searchResponse.status()); - } else { - List<AsyncQueryJobMetadata> list = new ArrayList<>(); - for (SearchHit searchHit : searchResponse.getHits().getHits()) { - String sourceAsString = searchHit.getSourceAsString(); - AsyncQueryJobMetadata asyncQueryJobMetadata; - try { - asyncQueryJobMetadata = AsyncQueryJobMetadata.toJobMetadata(sourceAsString); - } catch (IOException e) { - throw new RuntimeException(e); - } - list.add(asyncQueryJobMetadata); - } - return list; - } + public Optional<AsyncQueryJobMetadata> getJobMetadata(String qid) { + AsyncQueryId queryId = new AsyncQueryId(qid); + return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) + .apply(queryId.docId()); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java new file mode 100644 index 0000000000..b99ebe0e8c --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery.model; + +import static org.opensearch.sql.spark.utils.IDUtils.decode; +import static org.opensearch.sql.spark.utils.IDUtils.encode; + +import lombok.Data; + +/** Async query id. */ +@Data +public class AsyncQueryId { + private final String id; + + public static AsyncQueryId newAsyncQueryId(String datasourceName) { + return new AsyncQueryId(encode(datasourceName)); + } + + public String getDataSourceName() { + return decode(id); + } + + /** OpenSearch DocId. 
*/ + public String docId() { + return "qid" + id; + } + + @Override + public String toString() { + return "asyncQueryId=" + id; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index b80fefa173..3c59403661 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -8,37 +8,83 @@ package org.opensearch.sql.spark.asyncquery.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.sql.spark.execution.statement.StatementModel.QUERY_ID; import com.google.gson.Gson; import java.io.IOException; -import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; +import lombok.SneakyThrows; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.statestore.StateModel; /** This class models all the metadata required for a job. */ @Data -@AllArgsConstructor -@EqualsAndHashCode -public class AsyncQueryJobMetadata { - private String applicationId; - private String jobId; - private boolean isDropIndexQuery; - private String resultIndex; +@EqualsAndHashCode(callSuper = false) +public class AsyncQueryJobMetadata extends StateModel { + public static final String TYPE_JOBMETA = "jobmeta"; + + private final AsyncQueryId queryId; + private final String applicationId; + private final String jobId; + private final boolean isDropIndexQuery; + private final String resultIndex; // optional sessionId. - private String sessionId; + private final String sessionId; + + @EqualsAndHashCode.Exclude private final long seqNo; + @EqualsAndHashCode.Exclude private final long primaryTerm; - public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) { + public AsyncQueryJobMetadata( + AsyncQueryId queryId, String applicationId, String jobId, String resultIndex) { + this( + queryId, + applicationId, + jobId, + false, + resultIndex, + null, + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + } + + public AsyncQueryJobMetadata( + AsyncQueryId queryId, + String applicationId, + String jobId, + boolean isDropIndexQuery, + String resultIndex, + String sessionId) { + this( + queryId, + applicationId, + jobId, + isDropIndexQuery, + resultIndex, + sessionId, + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + } + + public AsyncQueryJobMetadata( + AsyncQueryId queryId, + String applicationId, + String jobId, + boolean isDropIndexQuery, + String resultIndex, + String sessionId, + long seqNo, + long primaryTerm) { + this.queryId = queryId; this.applicationId = applicationId; this.jobId = jobId; - this.isDropIndexQuery = false; + this.isDropIndexQuery = isDropIndexQuery; this.resultIndex = resultIndex; - this.sessionId = null; + this.sessionId = sessionId; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; } @Override @@ -49,39 +95,36 @@ public String toString() { /** * Converts JobMetadata to XContentBuilder. 
* - * @param metadata metadata. * @return XContentBuilder {@link XContentBuilder} * @throws Exception Exception. */ - public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) throws Exception { - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder.startObject(); - builder.field("jobId", metadata.getJobId()); - builder.field("applicationId", metadata.getApplicationId()); - builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); - builder.field("resultIndex", metadata.getResultIndex()); - builder.field("sessionId", metadata.getSessionId()); - builder.endObject(); + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(QUERY_ID, queryId.getId()) + .field("type", TYPE_JOBMETA) + .field("jobId", jobId) + .field("applicationId", applicationId) + .field("isDropIndexQuery", isDropIndexQuery) + .field("resultIndex", resultIndex) + .field("sessionId", sessionId) + .endObject(); return builder; } - /** - * Converts json string to DataSourceMetadata. - * - * @param json jsonstring. - * @return jobmetadata {@link AsyncQueryJobMetadata} - * @throws java.io.IOException IOException. - */ - public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOException { - try (XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - json)) { - return toJobMetadata(parser); - } + /** copy builder. update seqNo and primaryTerm */ + public static AsyncQueryJobMetadata copy( + AsyncQueryJobMetadata copy, long seqNo, long primaryTerm) { + return new AsyncQueryJobMetadata( + copy.getQueryId(), + copy.getApplicationId(), + copy.getJobId(), + copy.isDropIndexQuery(), + copy.getResultIndex(), + copy.getSessionId(), + seqNo, + primaryTerm); } /** @@ -91,17 +134,23 @@ public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOExceptio * @return JobMetadata {@link AsyncQueryJobMetadata} * @throws IOException IOException. 
*/ - public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException { + @SneakyThrows + public static AsyncQueryJobMetadata fromXContent( + XContentParser parser, long seqNo, long primaryTerm) { + AsyncQueryId queryId = null; String jobId = null; String applicationId = null; boolean isDropIndexQuery = false; String resultIndex = null; String sessionId = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { + case QUERY_ID: + queryId = new AsyncQueryId(parser.textOrNull()); + break; case "jobId": jobId = parser.textOrNull(); break; @@ -117,6 +166,8 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "sessionId": sessionId = parser.textOrNull(); break; + case "type": + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -125,6 +176,18 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws throw new IllegalArgumentException("jobId and applicationId are required fields."); } return new AsyncQueryJobMetadata( - applicationId, jobId, isDropIndexQuery, resultIndex, sessionId); + queryId, + applicationId, + jobId, + isDropIndexQuery, + resultIndex, + sessionId, + seqNo, + primaryTerm); + } + + @Override + public String getId() { + return queryId.docId(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java new file mode 100644 index 0000000000..77a0e1cd09 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + +import com.amazonaws.services.emrserverless.model.JobRunState; +import org.json.JSONObject; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; + +/** Process async query request. 
*/ +public abstract class AsyncQueryHandler { + + public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { + if (asyncQueryJobMetadata.isDropIndexQuery()) { + return SparkQueryDispatcher.DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()) + .result(); + } + + JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata); + if (result.has(DATA_FIELD)) { + JSONObject items = result.getJSONObject(DATA_FIELD); + + // If items have STATUS_FIELD, use it; otherwise, mark failed + String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString()); + result.put(STATUS_FIELD, status); + + // If items have ERROR_FIELD, use it; otherwise, set empty string + String error = items.optString(ERROR_FIELD, ""); + result.put(ERROR_FIELD, error); + return result; + } else { + return getResponseFromExecutor(asyncQueryJobMetadata); + } + } + + protected abstract JSONObject getResponseFromResultIndex( + AsyncQueryJobMetadata asyncQueryJobMetadata); + + protected abstract JSONObject getResponseFromExecutor( + AsyncQueryJobMetadata asyncQueryJobMetadata); + + abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java new file mode 100644 index 0000000000..8a582278e1 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import lombok.RequiredArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class BatchQueryHandler extends AsyncQueryHandler { + private final EMRServerlessClient emrServerlessClient; + private final JobExecutionResponseReader jobExecutionResponseReader; + + @Override + protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + // either empty json when the result is not available or data with status + // Fetch from Result Index + return jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + } + + @Override + protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + JSONObject result = new JSONObject(); + // make call to EMR Serverless when related result index documents are not available + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); + return result; + } + + @Override + public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + return asyncQueryJobMetadata.getQueryId().getId(); + } +} diff --git 
a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java new file mode 100644 index 0000000000..24ea1528c8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class InteractiveQueryHandler extends AsyncQueryHandler { + private final SessionManager sessionManager; + private final JobExecutionResponseReader jobExecutionResponseReader; + + @Override + protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + String queryId = asyncQueryJobMetadata.getQueryId().getId(); + return jobExecutionResponseReader.getResultWithQueryId( + queryId, asyncQueryJobMetadata.getResultIndex()); + } + + @Override + protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + JSONObject result = new JSONObject(); + String queryId = asyncQueryJobMetadata.getQueryId().getId(); + Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); + StatementState statementState = statement.getStatementState(); + result.put(STATUS_FIELD, statementState.getState()); + result.put(ERROR_FIELD, ""); + return result; + } + + @Override + public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + String queryId = asyncQueryJobMetadata.getQueryId().getId(); + getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel(); + return queryId; + } + + private Statement getStatementByQueryId(String sid, String qid) { + SessionId sessionId = new SessionId(sid); + Optional<Session> session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(qid); + Optional<Statement> statement = session.get().get(statementId); + if (statement.isPresent()) { + return statement.get(); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. 
" + sessionId); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 2bd1ae67b9..882f2663d9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -10,8 +10,6 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; -import com.amazonaws.services.emrserverless.model.CancelJobRunResult; -import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.nio.charset.StandardCharsets; import java.util.Base64; @@ -33,6 +31,7 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; @@ -46,9 +45,6 @@ import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statement.QueryRequest; -import org.opensearch.sql.spark.execution.statement.Statement; -import org.opensearch.sql.spark.execution.statement.StatementId; -import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -92,97 +88,22 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - // todo. refactor query process logic in plugin. - if (asyncQueryJobMetadata.isDropIndexQuery()) { - return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result(); - } - - JSONObject result; - if (asyncQueryJobMetadata.getSessionId() == null) { - // either empty json when the result is not available or data with status - // Fetch from Result Index - result = - jobExecutionResponseReader.getResultFromOpensearchIndex( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + if (asyncQueryJobMetadata.getSessionId() != null) { + return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader) + .getQueryResponse(asyncQueryJobMetadata); } else { - // when session enabled, jobId in asyncQueryJobMetadata is actually queryId. - result = - jobExecutionResponseReader.getResultWithQueryId( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); + return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader) + .getQueryResponse(asyncQueryJobMetadata); } - // if result index document has a status, we are gonna use the status directly; otherwise, we - // will use emr-s job status. - // That a job is successful does not mean there is no error in execution. For example, even if - // result - // index mapping is incorrect, we still write query result and let the job finish. 
- // That a job is running does not mean the status is running. For example, index/streaming Query - is a - long-running job which runs forever. But we need to return success from the result index - immediately. - if (result.has(DATA_FIELD)) { - JSONObject items = result.getJSONObject(DATA_FIELD); - - // If items have STATUS_FIELD, use it; otherwise, mark failed - String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString()); - result.put(STATUS_FIELD, status); - - // If items have ERROR_FIELD, use it; otherwise, set empty string - String error = items.optString(ERROR_FIELD, ""); - result.put(ERROR_FIELD, error); - } else { - if (asyncQueryJobMetadata.getSessionId() != null) { - SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); - Optional<Session> session = sessionManager.getSession(sessionId); - if (session.isPresent()) { - // todo, statementId == jobId if statement running in session. - StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); - Optional<Statement> statement = session.get().get(statementId); - if (statement.isPresent()) { - StatementState statementState = statement.get().getStatementState(); - result.put(STATUS_FIELD, statementState.getState()); - result.put(ERROR_FIELD, ""); - } else { - throw new IllegalArgumentException("no statement found. " + statementId); - } - } else { - throw new IllegalArgumentException("no session found. " + sessionId); - } - } else { - // make call to EMR Serverless when related result index documents are not available - GetJobRunResult getJobRunResult = - emrServerlessClient.getJobRunResult( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - String jobState = getJobRunResult.getJobRun().getState(); - result.put(STATUS_FIELD, jobState); - result.put(ERROR_FIELD, ""); - } - } - - return result; } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { if (asyncQueryJobMetadata.getSessionId() != null) { - SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); - Optional<Session> session = sessionManager.getSession(sessionId); - if (session.isPresent()) { - // todo, statementId == jobId if statement running in session. - StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); - Optional<Statement> statement = session.get().get(statementId); - if (statement.isPresent()) { - statement.get().cancel(); - return statementId.getId(); - } else { - throw new IllegalArgumentException("no statement found. " + statementId); - } - } else { - throw new IllegalArgumentException("no session found. 
" + sessionId); - } + return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader) + .cancelJob(asyncQueryJobMetadata); } else { - CancelJobRunResult cancelJobRunResult = - emrServerlessClient.cancelJobRun( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - return cancelJobRunResult.getJobRunId(); + return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader) + .cancelJob(asyncQueryJobMetadata); } } @@ -229,12 +150,18 @@ private DispatchQueryResponse handleIndexQuery( indexDetails.getAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); + return new DispatchQueryResponse( + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + jobId, + false, + dataSourceMetadata.getResultIndex(), + null); } private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); @@ -267,12 +194,12 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), dataSourceMetadata.getName())); } - StatementId statementId = - session.submit( - new QueryRequest( - dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); + session.submit( + new QueryRequest( + queryId, dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); return new DispatchQueryResponse( - statementId.getId(), + queryId, + session.getSessionModel().getJobId(), false, dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); @@ -294,7 +221,8 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ false, dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); + return new DispatchQueryResponse( + queryId, jobId, false, dataSourceMetadata.getResultIndex(), null); } } @@ -325,7 +253,11 @@ private DispatchQueryResponse handleDropIndexQuery( } } return new DispatchQueryResponse( - new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex(), null); + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + new DropIndexResult(status).toJobId(), + true, + dataSourceMetadata.getResultIndex(), + null); } private static Map<String, String> getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 893446c617..e44379daff 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -2,10 +2,12 @@ import lombok.AllArgsConstructor; import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Data @AllArgsConstructor public class DispatchQueryResponse { + private AsyncQueryId 
queryId; private String jobId; private boolean isDropIndexQuery; private String resultIndex; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 4428c3b83d..a2e7cfe6ee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -81,7 +81,8 @@ public StatementId submit(QueryRequest request) { } else { sessionModel = model.get(); if (!END_STATE.contains(sessionModel.getSessionState())) { - StatementId statementId = newStatementId(); + String qid = request.getQueryId().getId(); + StatementId statementId = newStatementId(qid); Statement st = Statement.builder() .sessionId(sessionId) @@ -92,7 +93,7 @@ public StatementId submit(QueryRequest request) { .langType(LangType.SQL) .datasourceName(sessionModel.getDatasourceName()) .query(request.getQuery()) - .queryId(statementId.getId()) + .queryId(qid) .build(); st.open(); return statementId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index b3bd716925..c85e4dd35c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -5,10 +5,10 @@ package org.opensearch.sql.spark.execution.session; -import java.nio.charset.StandardCharsets; -import java.util.Base64; +import static org.opensearch.sql.spark.utils.IDUtils.decode; +import static org.opensearch.sql.spark.utils.IDUtils.encode; + import lombok.Data; -import org.apache.commons.lang3.RandomStringUtils; @Data public class SessionId { @@ -24,15 +24,6 @@ public String getDataSourceName() { return decode(sessionId); } - private static String decode(String sessionId) { - return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN); - } - - private static String encode(String datasourceName) { - String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; - return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); - } - @Override public String toString() { return "sessionId=" + sessionId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java index 10061404ca..c365265224 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java @@ -6,10 +6,12 @@ package org.opensearch.sql.spark.execution.statement; import lombok.Data; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.rest.model.LangType; @Data public class QueryRequest { + private final AsyncQueryId queryId; private final LangType langType; private final String query; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java index d9381ad45f..33284c4b3d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -6,14 +6,14 @@ package 
org.opensearch.sql.spark.execution.statement; import lombok.Data; -import org.apache.commons.lang3.RandomStringUtils; @Data public class StatementId { private final String id; - public static StatementId newStatementId() { - return new StatementId(RandomStringUtils.randomAlphanumeric(16)); + // construct statementId from queryId. + public static StatementId newStatementId(String qid) { + return new StatementId(qid); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index a36ee3ef45..6546d303fb 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -38,6 +38,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statement.StatementModel; @@ -53,7 +54,6 @@ public class StateStore { public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml"; public static Function<String, String> DATASOURCE_TO_REQUEST_INDEX = datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); - public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME); private static final Logger LOG = LogManager.getLogger(); @@ -77,7 +77,6 @@ protected <T extends StateModel> T create( try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) { IndexResponse indexResponse = client.index(indexRequest).actionGet(); - ; if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { LOG.debug("Successfully created doc. 
id: {}", st.getId()); return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); @@ -227,10 +226,6 @@ public static Function<String, Optional<SessionModel>> getSession( docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function<String, Optional<SessionModel>> searchSession(StateStore stateStore) { - return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX); - } - public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState( StateStore stateStore, String datasourceName) { return (old, state) -> @@ -241,8 +236,21 @@ public static BiFunction<SessionModel, SessionState, SessionModel> updateSession DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) { - String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); - return () -> stateStore.createIndex(indexName); + public static Function<AsyncQueryJobMetadata, AsyncQueryJobMetadata> createJobMetaData( + StateStore stateStore, String datasourceName) { + return (jobMetadata) -> + stateStore.create( + jobMetadata, + AsyncQueryJobMetadata::copy, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + public static Function<String, Optional<AsyncQueryJobMetadata>> getJobMetaData( + StateStore stateStore, String datasourceName) { + return (docId) -> + stateStore.get( + docId, + AsyncQueryJobMetadata::fromXContent, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java new file mode 100644 index 0000000000..438d2342b4 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.RandomStringUtils; + +@UtilityClass +public class IDUtils { + public static final int PREFIX_LEN = 10; + + public static String decode(String id) { + return new String(Base64.getDecoder().decode(id)).substring(PREFIX_LEN); + } + + public static String encode(String datasourceName) { + String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; + return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml deleted file mode 100644 index 3a39b989a2..0000000000 --- a/spark/src/main/resources/job-metadata-index-mapping.yml +++ /dev/null @@ -1,25 +0,0 @@ ---- -## -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -## - -# Schema file for the .ql-job-metadata index -# Also "dynamic" is set to "false" so that other fields can be added. 
-dynamic: false -properties: - jobId: - type: text - fields: - keyword: - type: keyword - applicationId: - type: text - fields: - keyword: - type: keyword - resultIndex: - type: text - fields: - keyword: - type: keyword \ No newline at end of file diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/job-metadata-index-settings.yml deleted file mode 100644 index be93f4645c..0000000000 --- a/spark/src/main/resources/job-metadata-index-settings.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -## -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -## - -# Settings file for the .ql-job-metadata index -index: - number_of_shards: "1" - auto_expand_replicas: "0-2" - number_of_replicas: "0" \ No newline at end of file diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml index 87bd927e6e..fbe90a1cba 100644 --- a/spark/src/main/resources/query_execution_request_mapping.yml +++ b/spark/src/main/resources/query_execution_request_mapping.yml @@ -8,6 +8,8 @@ # Also "dynamic" is set to "false" so that other fields can be added. dynamic: false properties: + version: + type: keyword type: type: keyword state: diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 3eb8958eb2..1ee119df78 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -284,8 +284,9 @@ private DataSourceServiceImpl createDataSourceService() { private AsyncQueryExecutorService createAsyncQueryExecutorService( EMRServerlessClient emrServerlessClient) { + StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); + new OpensearchAsyncQueryJobMetadataStorageService(stateStore); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( @@ -295,8 +296,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager( - new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); + new SessionManager(stateStore, emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 0d4e280b61..2ed316795f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static 
org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; @@ -29,6 +30,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; @@ -47,6 +49,7 @@ public class AsyncQueryExecutorServiceImplTest { private AsyncQueryExecutorService jobExecutorService; @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; + private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); @BeforeEach void setUp() { @@ -78,11 +81,12 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); + .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) - .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID, null)); + .storeJobMetadata( + new AsyncQueryJobMetadata(QUERY_ID, "00fd775baqpu4g0p", EMR_JOB_ID, null)); verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig(); verify(sparkQueryDispatcher, times(1)) .dispatch( @@ -93,7 +97,7 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME)); - Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId()); + Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId()); } @Test @@ -107,7 +111,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { "--conf spark.dynamicAllocation.enabled=false", TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); + .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null)); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( @@ -139,11 +143,13 @@ void testGetAsyncQueryResultsWithJobNotFoundException() { @Test void testGetAsyncQueryResultsWithInProgressJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + .thenReturn( + Optional.of( + new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); when(sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) + new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); @@ -157,11 +163,13 @@ void testGetAsyncQueryResultsWithInProgressJob() { @Test void testGetAsyncQueryResultsWithSuccessJob() throws IOException { 
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + .thenReturn( + Optional.of( + new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); jobResult.put("status", JobRunState.SUCCESS.toString()); when(sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) + new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -208,9 +216,11 @@ void testCancelJobWithJobNotFound() { @Test void testCancelJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + .thenReturn( + Optional.of( + new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))); when(sparkQueryDispatcher.cancelJob( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) + new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(EMR_JOB_ID); String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index 7288fd3fc2..de0caf5589 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -5,242 +5,70 @@ package org.opensearch.sql.spark.asyncquery; -import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService.JOB_METADATA_INDEX; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import java.util.Optional; -import org.apache.lucene.search.TotalHits; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Answers; -import org.mockito.ArgumentMatchers; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.action.ActionFuture; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.test.OpenSearchIntegTestCase; -@ExtendWith(MockitoExtension.class) -public class 
OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest { +public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest + extends OpenSearchIntegTestCase { - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private Client client; - - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private ClusterService clusterService; - - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private SearchResponse searchResponse; - - @Mock private ActionFuture<SearchResponse> searchResponseActionFuture; - @Mock private ActionFuture<CreateIndexResponse> createIndexResponseActionFuture; - @Mock private ActionFuture<IndexResponse> indexResponseActionFuture; - @Mock private IndexResponse indexResponse; - @Mock private SearchHit searchHit; - - @InjectMocks + public static final String DS_NAME = "mys3"; + private static final String MOCK_SESSION_ID = "sessionId"; + private static final String MOCK_RESULT_INDEX = "resultIndex"; private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; - @Test - public void testStoreJobMetadata() { - - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(Boolean.FALSE); - Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) - .thenReturn(createIndexResponseActionFuture); - Mockito.when(createIndexResponseActionFuture.actionGet()) - .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); - Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); - Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); - Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); - AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); - - this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); - - Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); - Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); - Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); - } - - @Test - public void testStoreJobMetadataWithOutCreatingIndex() { - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(Boolean.TRUE); - Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); - Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); - Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); - AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); - - this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); - - Mockito.verify(client.admin().indices(), Mockito.times(0)).create(ArgumentMatchers.any()); - Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); - Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext(); - } - - @Test - public void testStoreJobMetadataWithException() { - - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(Boolean.FALSE); - Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) - .thenReturn(createIndexResponseActionFuture); - Mockito.when(createIndexResponseActionFuture.actionGet()) - .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); - Mockito.when(client.index(ArgumentMatchers.any())) - .thenThrow(new RuntimeException("error while indexing")); - 
AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); - RuntimeException runtimeException = - Assertions.assertThrows( - RuntimeException.class, - () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata)); - Assertions.assertEquals( - "java.lang.RuntimeException: error while indexing", runtimeException.getMessage()); - - Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); - Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); - Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); - } - - @Test - public void testStoreJobMetadataWithIndexCreationFailed() { - - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(Boolean.FALSE); - Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) - .thenReturn(createIndexResponseActionFuture); - Mockito.when(createIndexResponseActionFuture.actionGet()) - .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX)); - - AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); - RuntimeException runtimeException = - Assertions.assertThrows( - RuntimeException.class, - () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata)); - Assertions.assertEquals( - "Internal server error while creating.ql-job-metadata index:: " - + "Index creation is not acknowledged.", - runtimeException.getMessage()); - - Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); - Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext(); - } - - @Test - public void testStoreJobMetadataFailedWithNotFoundResponse() { - - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(Boolean.FALSE); - Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) - .thenReturn(createIndexResponseActionFuture); - Mockito.when(createIndexResponseActionFuture.actionGet()) - .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); - Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); - Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); - Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); - - AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); - RuntimeException runtimeException = - Assertions.assertThrows( - RuntimeException.class, - () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata)); - Assertions.assertEquals( - "Saving job metadata information failed with result : not_found", - runtimeException.getMessage()); - - Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); - Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); - Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); - } - - @Test - public void testGetJobMetadata() { - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(true); - Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); - Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); - 
Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK); - Mockito.when(searchResponse.getHits()) - .thenReturn( - new SearchHits( - new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); - AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null); - Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString()); - - Optional jobMetadataOptional = - opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID); - Assertions.assertTrue(jobMetadataOptional.isPresent()); - Assertions.assertEquals(EMR_JOB_ID, jobMetadataOptional.get().getJobId()); - Assertions.assertEquals(EMRS_APPLICATION_ID, jobMetadataOptional.get().getApplicationId()); + @Before + public void setup() { + opensearchJobMetadataStorageService = + new OpensearchAsyncQueryJobMetadataStorageService( + new StateStore(client(), clusterService())); } @Test - public void testGetJobMetadataWith404SearchResponse() { - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(true); - Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); - Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); - Mockito.when(searchResponse.status()).thenReturn(RestStatus.NOT_FOUND); - - RuntimeException runtimeException = - Assertions.assertThrows( - RuntimeException.class, - () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)); - Assertions.assertEquals( - "Fetching job metadata information failed with status : NOT_FOUND", - runtimeException.getMessage()); - } - - @Test - public void testGetJobMetadataWithParsingFailed() { - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(true); - Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); - Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); - Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK); - Mockito.when(searchResponse.getHits()) - .thenReturn( - new SearchHits( - new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); - Mockito.when(searchHit.getSourceAsString()).thenReturn("..tesJOBs"); - - Assertions.assertThrows( - RuntimeException.class, - () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)); + public void testStoreJobMetadata() { + AsyncQueryJobMetadata expected = + new AsyncQueryJobMetadata( + AsyncQueryId.newAsyncQueryId(DS_NAME), + EMR_JOB_ID, + EMRS_APPLICATION_ID, + MOCK_RESULT_INDEX); + + opensearchJobMetadataStorageService.storeJobMetadata(expected); + Optional actual = + opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + + assertTrue(actual.isPresent()); + assertEquals(expected, actual.get()); + assertFalse(actual.get().isDropIndexQuery()); + assertNull(actual.get().getSessionId()); } @Test - public void testGetJobMetadataWithNoIndex() { - Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) - .thenReturn(Boolean.FALSE); - Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) - .thenReturn(createIndexResponseActionFuture); - Mockito.when(createIndexResponseActionFuture.actionGet()) - .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); - Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); - - Optional jobMetadata = - 
opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID); - - Assertions.assertFalse(jobMetadata.isPresent()); + public void testStoreJobMetadataWithResultExtraData() { + AsyncQueryJobMetadata expected = + new AsyncQueryJobMetadata( + AsyncQueryId.newAsyncQueryId(DS_NAME), + EMR_JOB_ID, + EMRS_APPLICATION_ID, + true, + MOCK_RESULT_INDEX, + MOCK_SESSION_ID); + + opensearchJobMetadataStorageService.storeJobMetadata(expected); + Optional actual = + opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + + assertTrue(actual.isPresent()); + assertEquals(expected, actual.get()); + assertTrue(actual.get().isDropIndexQuery()); + assertEquals("resultIndex", actual.get().getResultIndex()); + assertEquals(MOCK_SESSION_ID, actual.get().getSessionId()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 15211dec01..4acccae0e2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; @@ -19,6 +20,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -47,7 +49,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -58,6 +59,7 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; @@ -86,19 +88,22 @@ public class SparkQueryDispatcherTest { @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @Mock private FlintIndexMetadataReader flintIndexMetadataReader; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) + @Mock(answer = RETURNS_DEEP_STUBS) private Client openSearchClient; @Mock private FlintIndexMetadata flintIndexMetadata; @Mock private SessionManager sessionManager; - @Mock private Session session; + @Mock(answer = RETURNS_DEEP_STUBS) + private Session session; @Mock private Statement statement; private SparkQueryDispatcher sparkQueryDispatcher; + private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); + @Captor ArgumentCaptor startJobRequestArgumentCaptor; @BeforeEach @@ -285,6 +290,7 @@ void testDispatchSelectQueryCreateNewSession() { 
doReturn(session).when(sessionManager).createSession(any()); doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -292,7 +298,7 @@ void testDispatchSelectQueryCreateNewSession() { verifyNoInteractions(emrServerlessClient); verify(sessionManager, never()).getSession(any()); - Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); } @@ -307,6 +313,7 @@ void testDispatchSelectQueryReuseSession() { .getSession(eq(new SessionId(MOCK_SESSION_ID))); doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); @@ -314,7 +321,7 @@ void testDispatchSelectQueryReuseSession() { verifyNoInteractions(emrServerlessClient); verify(sessionManager, never()).createSession(any()); - Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); } @@ -636,10 +643,8 @@ void testCancelJob() { new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String jobId = - sparkQueryDispatcher.cancelJob( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @Test @@ -698,10 +703,8 @@ void testCancelQueryWithNoSessionId() { new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String jobId = - sparkQueryDispatcher.cancelJob( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @Test @@ -712,9 +715,7 @@ void testGetQueryResponse() { // simulate result index is not created yet when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(new JSONObject()); - JSONObject result = - sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); Assertions.assertEquals("PENDING", result.get("status")); } @@ -790,9 +791,7 @@ void testGetQueryResponseWithSuccess() { queryResult.put(DATA_FIELD, resultMap); when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(queryResult); - JSONObject result = - 
sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null); Assertions.assertEquals( new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet()); @@ -827,7 +826,13 @@ void testGetQueryResponseOfDropIndex() { JSONObject result = sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null, null)); + new AsyncQueryJobMetadata( + AsyncQueryId.newAsyncQueryId(DS_NAME), + EMRS_APPLICATION_ID, + jobId, + true, + null, + null)); verify(jobExecutionResponseReader, times(0)) .getResultFromOpensearchIndex(anyString(), anyString()); Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD)); @@ -1210,8 +1215,13 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str sessionId); } + private AsyncQueryJobMetadata asyncQueryJobMetadata() { + return new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null); + } + private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( - String queryId, String sessionId) { - return new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, queryId, false, null, sessionId); + String statementId, String sessionId) { + return new AsyncQueryJobMetadata( + new AsyncQueryId(statementId), EMRS_APPLICATION_ID, EMR_JOB_ID, false, null, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index ff3ddd1bef..1e33c8a6b9 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -22,6 +22,7 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -208,7 +209,7 @@ public void submitStatementInRunningSession() { // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + StatementId statementId = session.submit(queryRequest()); assertFalse(statementId.getId().isEmpty()); } @@ -218,7 +219,7 @@ public void submitStatementInNotStartedState() { new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(createSessionRequest()); - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + StatementId statementId = session.submit(queryRequest()); assertFalse(statementId.getId().isEmpty()); } @@ -231,9 +232,7 @@ public void failToSubmitStatementInDeadState() { updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); assertEquals( "can't submit statement, session should not be 
in end state, current session state is:" + " dead", @@ -249,9 +248,7 @@ public void failToSubmitStatementInFailState() { updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); assertEquals( "can't submit statement, session should not be in end state, current session state is:" + " fail", @@ -263,7 +260,7 @@ public void newStatementFieldAssert() { Session session = new SessionManager(stateStore, emrsClient, sessionSetting(false)) .createSession(createSessionRequest()); - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + StatementId statementId = session.submit(queryRequest()); Optional statement = session.get(statementId); assertTrue(statement.isPresent()); @@ -288,9 +285,7 @@ public void failToSubmitStatementInDeletedSession() { .actionGet(); IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); } @@ -301,7 +296,7 @@ public void getStatementSuccess() { .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + StatementId statementId = session.submit(queryRequest()); Optional statement = session.get(statementId); assertTrue(statement.isPresent()); @@ -317,7 +312,7 @@ public void getStatementNotExist() { // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); - Optional statement = session.get(StatementId.newStatementId()); + Optional statement = session.get(StatementId.newStatementId("not-exist-id")); assertFalse(statement.isPresent()); } @@ -361,4 +356,8 @@ public TestStatement cancel() { return this; } } + + private QueryRequest queryRequest() { + return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1"); + } } From 877c9c3db3945ecbb84573e22ddd770c332aa6dd Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Mon, 23 Oct 2023 16:10:50 -0700 Subject: [PATCH 11/16] Add missing tags and MV support (#2336) Signed-off-by: Vamsi Manohar --- common/build.gradle | 2 +- integ-test/build.gradle | 2 +- ppl/build.gradle | 2 +- .../src/main/antlr/FlintSparkSqlExtensions.g4 | 34 +++ spark/src/main/antlr/SparkSqlBase.g4 | 5 + spark/src/main/antlr/SqlBaseLexer.g4 | 1 + spark/src/main/antlr/SqlBaseParser.g4 | 1 + .../spark/data/constants/SparkConstants.java | 2 - .../dispatcher/SparkQueryDispatcher.java | 33 ++- .../spark/dispatcher/model/IndexDetails.java | 145 +++++++++--- .../sql/spark/dispatcher/model/JobType.java | 37 +++ .../sql/spark/utils/SQLQueryUtils.java | 47 ++-- .../dispatcher/SparkQueryDispatcherTest.java | 222 ++++++++++++------ .../FlintIndexMetadataReaderImplTest.java | 58 ++--- .../sql/spark/flint/IndexDetailsTest.java | 13 +- .../sql/spark/utils/SQLQueryUtilsTest.java | 43 +++- 16 files changed, 460 insertions(+), 187 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java 
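Note on the MV support introduced below: the new materializedViewStatement rules in FlintSparkSqlExtensions.g4 are written to accept statements of roughly the following shapes (an illustrative sketch based only on the grammar in this patch; mv_1 and my_glue.default.logs are placeholder names, not part of the change):

    CREATE MATERIALIZED VIEW IF NOT EXISTS mv_1
    AS SELECT * FROM my_glue.default.logs
    WITH (auto_refresh = true)

    SHOW MATERIALIZED VIEWS IN my_glue.default
    DESC MATERIALIZED VIEW mv_1
    DROP MATERIALIZED VIEW mv_1

Under the reworked IndexDetails.openSearchIndexName() below, an MV named mv_1 maps to the OpenSearch index flint_mv_1, and auto-refresh index jobs are tagged job_type=streaming while interactive and batch queries use the other JobType values.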
diff --git a/common/build.gradle b/common/build.gradle index 5386c32468..416c9ca20a 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -35,7 +35,7 @@ repositories { dependencies { api "org.antlr:antlr4-runtime:4.7.1" api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' - api group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.21.0' + api group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' implementation 'com.github.babbel:okhttp-aws-signer:1.0.2' diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 08b8cfb210..85535da68c 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -175,7 +175,7 @@ dependencies { testImplementation group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}" testImplementation group: 'org.opensearch.driver', name: 'opensearch-sql-jdbc', version: System.getProperty("jdbcDriverVersion", '1.2.0.0') testImplementation group: 'org.hamcrest', name: 'hamcrest', version: '2.1' - implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.21.0' + implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" testImplementation project(':opensearch-sql-plugin') testImplementation project(':legacy') testImplementation('org.junit.jupiter:junit-jupiter-api:5.6.2') diff --git a/ppl/build.gradle b/ppl/build.gradle index d0d9fe3cbf..75281e9160 100644 --- a/ppl/build.gradle +++ b/ppl/build.gradle @@ -50,7 +50,7 @@ dependencies { implementation "org.antlr:antlr4-runtime:4.7.1" implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' api group: 'org.json', name: 'json', version: '20231013' - implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.21.0' + implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" api project(':common') api project(':core') api project(':protocol') diff --git a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 index e8e0264f28..c4af2779d1 100644 --- a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 +++ b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 @@ -17,6 +17,7 @@ singleStatement statement : skippingIndexStatement | coveringIndexStatement + | materializedViewStatement ; skippingIndexStatement @@ -76,6 +77,39 @@ dropCoveringIndexStatement : DROP INDEX indexName ON tableName ; +materializedViewStatement + : createMaterializedViewStatement + | showMaterializedViewStatement + | describeMaterializedViewStatement + | dropMaterializedViewStatement + ; + +createMaterializedViewStatement + : CREATE MATERIALIZED VIEW (IF NOT EXISTS)? mvName=multipartIdentifier + AS query=materializedViewQuery + (WITH LEFT_PAREN propertyList RIGHT_PAREN)? + ; + +showMaterializedViewStatement + : SHOW MATERIALIZED (VIEW | VIEWS) IN catalogDb=multipartIdentifier + ; + +describeMaterializedViewStatement + : (DESC | DESCRIBE) MATERIALIZED VIEW mvName=multipartIdentifier + ; + +dropMaterializedViewStatement + : DROP MATERIALIZED VIEW mvName=multipartIdentifier + ; + +/* + * Match all remaining tokens in non-greedy way + * so WITH clause won't be captured by this rule. + */ +materializedViewQuery + : .+? 
+ ; + indexColTypeList : indexColType (COMMA indexColType)* ; diff --git a/spark/src/main/antlr/SparkSqlBase.g4 b/spark/src/main/antlr/SparkSqlBase.g4 index 4ac1ced5c4..533d851ba6 100644 --- a/spark/src/main/antlr/SparkSqlBase.g4 +++ b/spark/src/main/antlr/SparkSqlBase.g4 @@ -154,6 +154,7 @@ COMMA: ','; DOT: '.'; +AS: 'AS'; CREATE: 'CREATE'; DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; @@ -161,14 +162,18 @@ DROP: 'DROP'; EXISTS: 'EXISTS'; FALSE: 'FALSE'; IF: 'IF'; +IN: 'IN'; INDEX: 'INDEX'; INDEXES: 'INDEXES'; +MATERIALIZED: 'MATERIALIZED'; NOT: 'NOT'; ON: 'ON'; PARTITION: 'PARTITION'; REFRESH: 'REFRESH'; SHOW: 'SHOW'; TRUE: 'TRUE'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; WITH: 'WITH'; diff --git a/spark/src/main/antlr/SqlBaseLexer.g4 b/spark/src/main/antlr/SqlBaseLexer.g4 index d9128de0f5..e8b5cb012f 100644 --- a/spark/src/main/antlr/SqlBaseLexer.g4 +++ b/spark/src/main/antlr/SqlBaseLexer.g4 @@ -447,6 +447,7 @@ PIPE: '|'; CONCAT_PIPE: '||'; HAT: '^'; COLON: ':'; +DOUBLE_COLON: '::'; ARROW: '->'; FAT_ARROW : '=>'; HENT_START: '/*+'; diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 index 77a9108e06..84a31dafed 100644 --- a/spark/src/main/antlr/SqlBaseParser.g4 +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -957,6 +957,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast + | primaryExpression DOUBLE_COLON dataType #castByColon | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first | ANY_VALUE LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #any_value diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 85ce3c4989..e8659c680c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -26,8 +26,6 @@ public class SparkConstants { public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; // TODO should be replaced with mvn jar. 
- public static final String FLINT_CATALOG_JAR = - "s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar"; public static final String FLINT_DEFAULT_HOST = "localhost"; public static final String FLINT_DEFAULT_PORT = "9200"; public static final String FLINT_DEFAULT_SCHEME = "http"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 882f2663d9..ff7ccf8c08 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -38,8 +38,8 @@ import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; -import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -59,9 +59,8 @@ public class SparkQueryDispatcher { public static final String INDEX_TAG_KEY = "index"; public static final String DATASOURCE_TAG_KEY = "datasource"; - public static final String SCHEMA_TAG_KEY = "schema"; - public static final String TABLE_TAG_KEY = "table"; public static final String CLUSTER_NAME_TAG_KEY = "cluster"; + public static final String JOB_TYPE_TAG_KEY = "job_type"; private EMRServerlessClient emrServerlessClient; @@ -111,6 +110,8 @@ private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryR if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) { IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + fillMissingDetails(dispatchQueryRequest, indexDetails); + if (indexDetails.isDropIndex()) { return handleDropIndexQuery(dispatchQueryRequest, indexDetails); } else { @@ -121,17 +122,29 @@ private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryR } } + // TODO: Revisit this logic. + // Currently, if the datasource is not provided in the query, + // Spark assumes the datasource to be the catalog. + // This is required to handle the drop index case properly when the datasource name is not provided. 
+ private static void fillMissingDetails( + DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { + if (indexDetails.getFullyQualifiedTableName() != null + && indexDetails.getFullyQualifiedTableName().getDatasourceName() == null) { + indexDetails + .getFullyQualifiedTableName() + .setDatasourceName(dispatchQueryRequest.getDatasource()); + } + } + private DispatchQueryResponse handleIndexQuery( DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { - FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); - tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); - tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); + tags.put(INDEX_TAG_KEY, indexDetails.openSearchIndexName()); + tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), @@ -142,12 +155,12 @@ private DispatchQueryResponse handleIndexQuery( .dataSource( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) - .structuredStreaming(indexDetails.getAutoRefresh()) + .structuredStreaming(indexDetails.isAutoRefresh()) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), tags, - indexDetails.getAutoRefresh(), + indexDetails.isAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); return new DispatchQueryResponse( @@ -178,6 +191,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ session = createdSession.get(); } else { // create session if not exist + tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText()); session = sessionManager.createSession( new CreateSessionRequest( @@ -204,6 +218,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); } else { + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java index 1cc66da9fc..42e2905e67 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java @@ -5,56 +5,129 @@ package org.opensearch.sql.spark.dispatcher.model; -import lombok.AllArgsConstructor; -import lombok.Data; +import com.google.common.base.Preconditions; import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; +import lombok.Getter; +import org.apache.commons.lang3.StringUtils; import org.opensearch.sql.spark.flint.FlintIndexType; /** Index details in an async query. 
*/ -@Data -@AllArgsConstructor -@NoArgsConstructor +@Getter @EqualsAndHashCode public class IndexDetails { + + public static final String STRIP_CHARS = "`"; + private String indexName; private FullyQualifiedTableName fullyQualifiedTableName; // by default, auto_refresh = false; - private Boolean autoRefresh = false; + private boolean autoRefresh; private boolean isDropIndex; + // materialized view special case where + // table name and mv name are combined. + private String mvName; private FlintIndexType indexType; + private IndexDetails() {} + + public static IndexDetailsBuilder builder() { + return new IndexDetailsBuilder(); + } + + // Builder class + public static class IndexDetailsBuilder { + private final IndexDetails indexDetails; + + public IndexDetailsBuilder() { + indexDetails = new IndexDetails(); + } + + public IndexDetailsBuilder indexName(String indexName) { + indexDetails.indexName = indexName; + return this; + } + + public IndexDetailsBuilder fullyQualifiedTableName(FullyQualifiedTableName tableName) { + indexDetails.fullyQualifiedTableName = tableName; + return this; + } + + public IndexDetailsBuilder autoRefresh(Boolean autoRefresh) { + indexDetails.autoRefresh = autoRefresh; + return this; + } + + public IndexDetailsBuilder isDropIndex(boolean isDropIndex) { + indexDetails.isDropIndex = isDropIndex; + return this; + } + + public IndexDetailsBuilder mvName(String mvName) { + indexDetails.mvName = mvName; + return this; + } + + public IndexDetailsBuilder indexType(FlintIndexType indexType) { + indexDetails.indexType = indexType; + return this; + } + + public IndexDetails build() { + Preconditions.checkNotNull(indexDetails.indexType, "Index Type can't be null"); + switch (indexDetails.indexType) { + case COVERING: + Preconditions.checkNotNull( + indexDetails.indexName, "IndexName can't be null for Covering Index."); + Preconditions.checkNotNull( + indexDetails.fullyQualifiedTableName, "TableName can't be null for Covering Index."); + break; + case SKIPPING: + Preconditions.checkNotNull( + indexDetails.fullyQualifiedTableName, "TableName can't be null for Skipping Index."); + break; + case MATERIALIZED_VIEW: + Preconditions.checkNotNull(indexDetails.mvName, "Materialized view name can't be null"); + break; + } + + return indexDetails; + } + } + public String openSearchIndexName() { FullyQualifiedTableName fullyQualifiedTableName = getFullyQualifiedTableName(); - if (FlintIndexType.SKIPPING.equals(getIndexType())) { - String indexName = - "flint" - + "_" - + fullyQualifiedTableName.getDatasourceName() - + "_" - + fullyQualifiedTableName.getSchemaName() - + "_" - + fullyQualifiedTableName.getTableName() - + "_" - + getIndexType().getSuffix(); - return indexName.toLowerCase(); - } else if (FlintIndexType.COVERING.equals(getIndexType())) { - String indexName = - "flint" - + "_" - + fullyQualifiedTableName.getDatasourceName() - + "_" - + fullyQualifiedTableName.getSchemaName() - + "_" - + fullyQualifiedTableName.getTableName() - + "_" - + getIndexName() - + "_" - + getIndexType().getSuffix(); - return indexName.toLowerCase(); - } else { - throw new UnsupportedOperationException( - String.format("Unsupported Index Type : %s", getIndexType())); + String indexName = StringUtils.EMPTY; + switch (getIndexType()) { + case COVERING: + indexName = + "flint" + + "_" + + StringUtils.strip(fullyQualifiedTableName.getDatasourceName(), STRIP_CHARS) + + "_" + + StringUtils.strip(fullyQualifiedTableName.getSchemaName(), STRIP_CHARS) + + "_" + + 
StringUtils.strip(fullyQualifiedTableName.getTableName(), STRIP_CHARS) + "_" + StringUtils.strip(getIndexName(), STRIP_CHARS) + "_" + getIndexType().getSuffix(); + break; + case SKIPPING: + indexName = + "flint" + "_" + StringUtils.strip(fullyQualifiedTableName.getDatasourceName(), STRIP_CHARS) + "_" + StringUtils.strip(fullyQualifiedTableName.getSchemaName(), STRIP_CHARS) + "_" + StringUtils.strip(fullyQualifiedTableName.getTableName(), STRIP_CHARS) + "_" + getIndexType().getSuffix(); + break; + case MATERIALIZED_VIEW: + indexName = "flint" + "_" + StringUtils.strip(getMvName(), STRIP_CHARS).toLowerCase(); + break; } + return indexName.toLowerCase(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java new file mode 100644 index 0000000000..01f5f422e9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +public enum JobType { + INTERACTIVE("interactive"), + STREAMING("streaming"), + BATCH("batch"); + + private String text; + + JobType(String text) { + this.text = text; + } + + public String getText() { + return this.text; + } + + /** + * Get JobType from text. + * + * @param text text. + * @return JobType {@link JobType}. + */ + public static JobType fromString(String text) { + for (JobType jobType : JobType.values()) { + if (jobType.text.equalsIgnoreCase(text)) { + return jobType; + } + } + throw new IllegalArgumentException("No JobType with text " + text + " found"); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index f6b75d49ef..4816f1c2cd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -52,7 +52,7 @@ public static IndexDetails extractIndexDetails(String sqlQuery) { flintSparkSqlExtensionsParser.statement(); FlintSQLIndexDetailsVisitor flintSQLIndexDetailsVisitor = new FlintSQLIndexDetailsVisitor(); statementContext.accept(flintSQLIndexDetailsVisitor); - return flintSQLIndexDetailsVisitor.getIndexDetails(); + return flintSQLIndexDetailsVisitor.getIndexDetailsBuilder().build(); } public static boolean isIndexQuery(String sqlQuery) { @@ -117,29 +117,29 @@ public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { public static class FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor { - @Getter private final IndexDetails indexDetails; + @Getter private final IndexDetails.IndexDetailsBuilder indexDetailsBuilder; public FlintSQLIndexDetailsVisitor() { - this.indexDetails = new IndexDetails(); + this.indexDetailsBuilder = new IndexDetails.IndexDetailsBuilder(); } @Override public Void visitIndexName(FlintSparkSqlExtensionsParser.IndexNameContext ctx) { - indexDetails.setIndexName(ctx.getText()); + indexDetailsBuilder.indexName(ctx.getText()); return super.visitIndexName(ctx); } @Override public Void visitTableName(FlintSparkSqlExtensionsParser.TableNameContext ctx) { - indexDetails.setFullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); + indexDetailsBuilder.fullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); return super.visitTableName(ctx); } @Override 
public Void visitCreateSkippingIndexStatement( FlintSparkSqlExtensionsParser.CreateSkippingIndexStatementContext ctx) { - indexDetails.setDropIndex(false); - indexDetails.setIndexType(FlintIndexType.SKIPPING); + indexDetailsBuilder.isDropIndex(false); + indexDetailsBuilder.indexType(FlintIndexType.SKIPPING); visitPropertyList(ctx.propertyList()); return super.visitCreateSkippingIndexStatement(ctx); } @@ -147,28 +147,47 @@ public Void visitCreateSkippingIndexStatement( @Override public Void visitCreateCoveringIndexStatement( FlintSparkSqlExtensionsParser.CreateCoveringIndexStatementContext ctx) { - indexDetails.setDropIndex(false); - indexDetails.setIndexType(FlintIndexType.COVERING); + indexDetailsBuilder.isDropIndex(false); + indexDetailsBuilder.indexType(FlintIndexType.COVERING); visitPropertyList(ctx.propertyList()); return super.visitCreateCoveringIndexStatement(ctx); } + @Override + public Void visitCreateMaterializedViewStatement( + FlintSparkSqlExtensionsParser.CreateMaterializedViewStatementContext ctx) { + indexDetailsBuilder.isDropIndex(false); + indexDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexDetailsBuilder.mvName(ctx.mvName.getText()); + visitPropertyList(ctx.propertyList()); + return super.visitCreateMaterializedViewStatement(ctx); + } + @Override public Void visitDropCoveringIndexStatement( FlintSparkSqlExtensionsParser.DropCoveringIndexStatementContext ctx) { - indexDetails.setDropIndex(true); - indexDetails.setIndexType(FlintIndexType.COVERING); + indexDetailsBuilder.isDropIndex(true); + indexDetailsBuilder.indexType(FlintIndexType.COVERING); return super.visitDropCoveringIndexStatement(ctx); } @Override public Void visitDropSkippingIndexStatement( FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext ctx) { - indexDetails.setDropIndex(true); - indexDetails.setIndexType(FlintIndexType.SKIPPING); + indexDetailsBuilder.isDropIndex(true); + indexDetailsBuilder.indexType(FlintIndexType.SKIPPING); return super.visitDropSkippingIndexStatement(ctx); } + @Override + public Void visitDropMaterializedViewStatement( + FlintSparkSqlExtensionsParser.DropMaterializedViewStatementContext ctx) { + indexDetailsBuilder.isDropIndex(true); + indexDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexDetailsBuilder.mvName(ctx.mvName.getText()); + return super.visitDropMaterializedViewStatement(ctx); + } + @Override public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext ctx) { if (ctx != null) { @@ -180,7 +199,7 @@ public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext // https://github.com/apache/spark/blob/v3.5.0/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala#L35 to unescape string literal if (propertyKey(property.key).toLowerCase(Locale.ROOT).contains("auto_refresh")) { if (propertyValue(property.value).toLowerCase(Locale.ROOT).contains("true")) { - indexDetails.setAutoRefresh(true); + indexDetailsBuilder.autoRefresh(true); } } }); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 4acccae0e2..700acb973e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -67,6 +67,7 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import 
org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; @@ -124,6 +125,7 @@ void testDispatchSelectQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("job_type", JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -178,6 +180,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("job_type", JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -233,6 +236,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("job_type", JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -365,10 +369,9 @@ void testDispatchSelectQueryFailedCreateSession() { void testDispatchIndexQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); - tags.put("table", "http_logs"); - tags.put("index", "elb_and_requestUri"); + tags.put("index", "flint_my_glue_default_http_logs_elb_and_requesturi_index"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("schema", "default"); + tags.put("job_type", JobType.STREAMING.getText()); String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; @@ -426,7 +429,7 @@ void testDispatchWithPPLQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - + tags.put("job_type", JobType.BATCH.getText()); String query = "source = my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -481,7 +484,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - + tags.put("job_type", JobType.BATCH.getText()); String query = "show tables"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -535,11 +538,9 @@ void testDispatchQueryWithoutATableAndDataSourceName() { void testDispatchIndexQueryWithoutADatasourceName() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); - tags.put("table", "http_logs"); - tags.put("index", "elb_and_requestUri"); + tags.put("index", "flint_my_glue_default_http_logs_elb_and_requesturi_index"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("schema", "default"); - + tags.put("job_type", JobType.STREAMING.getText()); String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; @@ -592,6 +593,65 @@ void testDispatchIndexQueryWithoutADatasourceName() { verifyNoInteractions(flintIndexMetadataReader); } + @Test + void testDispatchMaterializedViewQuery() { + HashMap tags = new 
HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("index", "flint_mv_1"); + tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("job_type", JobType.STREAMING.getText()); + String query = + "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + @Test void testDispatchWithWrongURI() { when(dataSourceService.getRawDataSourceMetadata("my_glue")) @@ -841,13 +901,15 @@ void testGetQueryResponseOfDropIndex() { @Test void testDropIndexQuery() throws ExecutionException, InterruptedException { String query = "DROP INDEX size_year ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - "size_year", - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.COVERING))) + IndexDetails indexDetails = + IndexDetails.builder() + .indexName("size_year") + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.COVERING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); // auto_refresh == true @@ -876,15 +938,7 @@ void testDropIndexQuery() throws ExecutionException, InterruptedException { TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - "size_year", - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.COVERING)); - + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); 
@@ -894,13 +948,14 @@ void testDropIndexQuery() throws ExecutionException, InterruptedException { @Test void testDropSkippingIndexQuery() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING))) + IndexDetails indexDetails = + IndexDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.SKIPPING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); when(flintIndexMetadata.isAutoRefresh()).thenReturn(true); @@ -927,14 +982,7 @@ void testDropSkippingIndexQuery() throws ExecutionException, InterruptedExceptio TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -945,13 +993,14 @@ void testDropSkippingIndexQuery() throws ExecutionException, InterruptedExceptio void testDropSkippingIndexQueryAutoRefreshFalse() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING))) + IndexDetails indexDetails = + IndexDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.SKIPPING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.isAutoRefresh()).thenReturn(false); @@ -972,14 +1021,7 @@ void testDropSkippingIndexQueryAutoRefreshFalse() TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(0)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -990,13 +1032,14 @@ void testDropSkippingIndexQueryAutoRefreshFalse() void testDropSkippingIndexQueryDeleteIndexException() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING 
INDEX ON my_glue.default.http_logs"; - when(flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING))) + IndexDetails indexDetails = + IndexDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.SKIPPING) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.isAutoRefresh()).thenReturn(false); @@ -1018,14 +1061,7 @@ void testDropSkippingIndexQueryDeleteIndexException() TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(0)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)) - .getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("my_glue.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.FAILED.toString(), dropIndexResult.getStatus()); @@ -1035,6 +1071,52 @@ void testDropSkippingIndexQueryDeleteIndexException() Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); } + @Test + void testDropMVQuery() throws ExecutionException, InterruptedException { + String query = "DROP MATERIALIZED VIEW mv_1"; + IndexDetails indexDetails = + IndexDetails.builder() + .mvName("mv_1") + .isDropIndex(true) + .fullyQualifiedTableName(null) + .indexType(FlintIndexType.MATERIALIZED_VIEW) + .build(); + when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) + .thenReturn(flintIndexMetadata); + when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); + // auto_refresh == true + when(flintIndexMetadata.isAutoRefresh()).thenReturn(true); + + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + + AcknowledgedResponse acknowledgedResponse = mock(AcknowledgedResponse.class); + when(openSearchClient.admin().indices().delete(any()).get()).thenReturn(acknowledgedResponse); + when(acknowledgedResponse.isAcknowledged()).thenReturn(true); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); + SparkQueryDispatcher.DropIndexResult dropIndexResult = + SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); + 
Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); + } + @Test void testDispatchQueryWithExtraSparkSubmitParameters() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java index b0c8491b0b..3cc40e0df5 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java @@ -44,12 +44,12 @@ void testGetJobIdFromFlintSkippingIndexMetadata() { FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client); FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - null, - new FullyQualifiedTableName("mys3.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING)); + IndexDetails.builder() + .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.SKIPPING) + .build()); Assertions.assertEquals("00fdmvv9hp8u0o0q", indexMetadata.getJobId()); } @@ -64,12 +64,13 @@ void testGetJobIdFromFlintCoveringIndexMetadata() { FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client); FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - "cv1", - new FullyQualifiedTableName("mys3.default.http_logs"), - false, - true, - FlintIndexType.COVERING)); + IndexDetails.builder() + .indexName("cv1") + .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.COVERING) + .build()); Assertions.assertEquals("00fdmvv9hp8u0o0q", indexMetadata.getJobId()); } @@ -86,34 +87,17 @@ void testGetJobIDWithNPEException() { IllegalArgumentException.class, () -> flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - "cv1", - new FullyQualifiedTableName("mys3.default.http_logs"), - false, - true, - FlintIndexType.COVERING))); + IndexDetails.builder() + .indexName("cv1") + .fullyQualifiedTableName( + new FullyQualifiedTableName("mys3.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.COVERING) + .build())); Assertions.assertEquals("Provided Index doesn't exist", illegalArgumentException.getMessage()); } - @SneakyThrows - @Test - void testGetJobIdFromUnsupportedIndex() { - FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client); - UnsupportedOperationException unsupportedOperationException = - Assertions.assertThrows( - UnsupportedOperationException.class, - () -> - flintIndexMetadataReader.getFlintIndexMetadata( - new IndexDetails( - "cv1", - new FullyQualifiedTableName("mys3.default.http_logs"), - false, - true, - FlintIndexType.MATERIALIZED_VIEW))); - Assertions.assertEquals( - "Unsupported Index Type : MATERIALIZED_VIEW", unsupportedOperationException.getMessage()); - } - @SneakyThrows public void mockNodeClientIndicesMappings(String indexName, String mappings) { GetMappingsResponse mockResponse = mock(GetMappingsResponse.class); diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java index 46fa4f7dbe..cf6b5f8f2b 100644 --- 
a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java @@ -16,12 +16,13 @@ public class IndexDetailsTest { public void skippingIndexName() { assertEquals( "flint_mys3_default_http_logs_skipping_index", - new IndexDetails( - "invalid", - new FullyQualifiedTableName("mys3.default.http_logs"), - false, - true, - FlintIndexType.SKIPPING) + IndexDetails.builder() + .indexName("invalid") + .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs")) + .autoRefresh(false) + .isDropIndex(true) + .indexType(FlintIndexType.SKIPPING) + .build() .openSearchIndexName()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index af892fa097..01759c2bdd 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.utils; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index; +import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex; import lombok.Getter; @@ -112,50 +113,67 @@ void testExtractionFromFlintIndexQueries() { Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); } + @Test + void testExtractionFromFlintMVQuery() { + String createCoveredIndexQuery = + "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + + " (auto_refresh = true)"; + Assertions.assertTrue(SQLQueryUtils.isIndexQuery(createCoveredIndexQuery)); + IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertNull(fullyQualifiedTableName); + Assertions.assertEquals("mv_1", indexDetails.getMvName()); + } + /** https://github.com/opensearch-project/sql/issues/2206 */ @Test void testAutoRefresh() { Assertions.assertFalse( - SQLQueryUtils.extractIndexDetails(skippingIndex().getQuery()).getAutoRefresh()); + SQLQueryUtils.extractIndexDetails(skippingIndex().getQuery()).isAutoRefresh()); Assertions.assertFalse( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "false").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); Assertions.assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "true").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); Assertions.assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("\"auto_refresh\"", "true").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); Assertions.assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("\"auto_refresh\"", "\"true\"").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); Assertions.assertFalse( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "1").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); Assertions.assertFalse( SQLQueryUtils.extractIndexDetails(skippingIndex().withProperty("interval", "1").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); - Assertions.assertFalse(SQLQueryUtils.extractIndexDetails(index().getQuery()).getAutoRefresh()); + 
Assertions.assertFalse(SQLQueryUtils.extractIndexDetails(index().getQuery()).isAutoRefresh()); Assertions.assertFalse( SQLQueryUtils.extractIndexDetails(index().withProperty("auto_refresh", "false").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); Assertions.assertTrue( SQLQueryUtils.extractIndexDetails(index().withProperty("auto_refresh", "true").getQuery()) - .getAutoRefresh()); + .isAutoRefresh()); + + Assertions.assertTrue( + SQLQueryUtils.extractIndexDetails(mv().withProperty("auto_refresh", "true").getQuery()) + .isAutoRefresh()); } @Getter @@ -176,6 +194,11 @@ public static IndexQuery index() { "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, " + "l_quantity)"); } + public static IndexQuery mv() { + return new IndexQuery( + "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs"); + } + public IndexQuery withProperty(String key, String value) { query = String.format("%s with (%s = %s)", query, key, value); return this; From a27e733ae7e0bbfd6a18a93cedd46d89c2c51904 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 24 Oct 2023 07:55:07 -0700 Subject: [PATCH 12/16] Fix bug, using basic instead of basicauth (#2342) * Fix bug, using basic instead of basicauth Signed-off-by: Peng Huo * fix codestyle Signed-off-by: Peng Huo * fix IT failure: datasourceWithBasicAuth Signed-off-by: Peng Huo * fix UT Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- .../model/SparkSubmitParameters.java | 3 +- .../dispatcher/InteractiveQueryHandler.java | 2 +- .../session/CreateSessionRequest.java | 30 ++++- ...AsyncQueryExecutorServiceImplSpecTest.java | 113 +++++++++++++++++- .../dispatcher/SparkQueryDispatcherTest.java | 6 +- 5 files changed, 146 insertions(+), 8 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index db78abb2a8..9a73b0f364 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -31,6 +31,7 @@ public class SparkSubmitParameters { public static final String SPACE = " "; public static final String EQUALS = "="; + public static final String FLINT_BASIC_AUTH = "basic"; private final String className; private final Map config; @@ -114,7 +115,7 @@ private void setFlintIndexStoreAuthProperties( Supplier password, Supplier region) { if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, authType); + config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 24ea1528c8..52cc2efbe2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -39,7 +39,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); StatementState statementState = 
statement.getStatementState();
     result.put(STATUS_FIELD, statementState.getState());
-    result.put(ERROR_FIELD, "");
+    result.put(ERROR_FIELD, Optional.ofNullable(statement.getStatementModel().getError()).orElse(""));
     return result;
   }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
index ca2b2b4867..b2201fbd01 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
@@ -21,14 +21,40 @@ public class CreateSessionRequest {
   private final String datasourceName;

   public StartJobRequest getStartJobRequest() {
-    return new StartJobRequest(
+    return new InteractiveSessionStartJobRequest(
         "select 1",
         jobName,
         applicationId,
         executionRoleArn,
         sparkSubmitParametersBuilder.build().toString(),
         tags,
-        false,
         resultIndex);
   }
+
+  static class InteractiveSessionStartJobRequest extends StartJobRequest {
+    public InteractiveSessionStartJobRequest(
+        String query,
+        String jobName,
+        String applicationId,
+        String executionRoleArn,
+        String sparkSubmitParams,
+        Map<String, String> tags,
+        String resultIndex) {
+      super(
+          query,
+          jobName,
+          applicationId,
+          executionRoleArn,
+          sparkSubmitParams,
+          tags,
+          false,
+          resultIndex);
+    }
+
+    /** Interactive queries keep running, so no execution timeout applies. */
+    @Override
+    public Long executionTimeout() {
+      return 0L;
+    }
+  }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
index 1ee119df78..19edd53eae 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
@@ -17,6 +17,7 @@
 import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;

 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
 import com.amazonaws.services.emrserverless.model.GetJobRunResult;
@@ -26,7 +27,9 @@
 import com.google.common.collect.ImmutableSet;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import lombok.Getter;
 import org.junit.After;
@@ -109,7 +112,7 @@ public void setup() {
                 "glue.auth.role_arn",
                 "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole",
                 "glue.indexstore.opensearch.uri",
-                "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200",
+                "http://localhost:9200",
                 "glue.indexstore.opensearch.auth",
                 "noauth"),
             null));
@@ -269,8 +272,114 @@ public void reuseSessionWhenCreateAsyncQuery() {
     assertEquals(second.getQueryId(), secondModel.get().getQueryId());
   }

+  @Test
+  public void batchQueryHasTimeout() {
+    LocalEMRSClient emrsClient = new LocalEMRSClient();
+    AsyncQueryExecutorService asyncQueryExecutorService =
+        createAsyncQueryExecutorService(emrsClient);
+
+    enableSession(false);
+    CreateAsyncQueryResponse response =
+        asyncQueryExecutorService.createAsyncQuery(
+            new
CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + + assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout()); + } + + @Test + public void interactiveQueryNoTimeout() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertEquals(0L, (long) emrsClient.getJobRequest().executionTimeout()); + } + + @Test + public void datasourceWithBasicAuth() { + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "basicauth"); + properties.put("glue.indexstore.opensearch.auth.username", "username"); + properties.put("glue.indexstore.opensearch.auth.password", "password"); + + dataSourceService.createDataSource( + new DataSourceMetadata( + "mybasicauth", DataSourceType.S3GLUE, ImmutableList.of(), properties, null)); + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null)); + String params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertTrue(params.contains(String.format("--conf spark.datasource.flint.auth=basic"))); + assertTrue( + params.contains(String.format("--conf spark.datasource.flint.auth.username=username"))); + assertTrue( + params.contains(String.format("--conf spark.datasource.flint.auth.password=password"))); + } + + @Test + public void withSessionCreateAsyncQueryFailed() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("myselect 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(response.getSessionId()); + Optional statementModel = + getStatement(stateStore, DATASOURCE).apply(response.getQueryId()); + assertTrue(statementModel.isPresent()); + assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); + + // 2. fetch async query result. not result write to SPARK_RESPONSE_BUFFER_INDEX_NAME yet. + // mock failed statement. 
+ StatementModel submitted = statementModel.get(); + StatementModel mocked = + StatementModel.builder() + .version("1.0") + .statementState(submitted.getStatementState()) + .statementId(submitted.getStatementId()) + .sessionId(submitted.getSessionId()) + .applicationId(submitted.getApplicationId()) + .jobId(submitted.getJobId()) + .langType(submitted.getLangType()) + .datasourceName(submitted.getDatasourceName()) + .query(submitted.getQuery()) + .queryId(submitted.getQueryId()) + .submitTime(submitted.getSubmitTime()) + .error("mock error") + .seqNo(submitted.getSeqNo()) + .primaryTerm(submitted.getPrimaryTerm()) + .build(); + updateStatementState(stateStore, DATASOURCE).apply(mocked, StatementState.FAILED); + + AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals(StatementState.FAILED.getState(), asyncQueryResults.getStatus()); + assertEquals("mock error", asyncQueryResults.getError()); + } + private DataSourceServiceImpl createDataSourceService() { - String masterKey = "1234567890"; + String masterKey = "a57d991d9b573f75b9bba1df"; DataSourceMetadataStorage dataSourceMetadataStorage = new OpenSearchDataSourceMetadataStorage( client, clusterService, new EncryptorImpl(masterKey)); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 700acb973e..a69c6e2b1a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -99,7 +99,8 @@ public class SparkQueryDispatcherTest { @Mock(answer = RETURNS_DEEP_STUBS) private Session session; - @Mock private Statement statement; + @Mock(answer = RETURNS_DEEP_STUBS) + private Statement statement; private SparkQueryDispatcher sparkQueryDispatcher; @@ -184,7 +185,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( - "basicauth", + "basic", new HashMap<>() { { put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); @@ -783,6 +784,7 @@ void testGetQueryResponse() { void testGetQueryResponseWithSession() { doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); doReturn(Optional.of(statement)).when(session).get(any()); + when(statement.getStatementModel().getError()).thenReturn("mock error"); doReturn(StatementState.WAITING).when(statement).getStatementState(); doReturn(new JSONObject()) From 4d44c091b272e67e8d2ef882709e876fa3476813 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 24 Oct 2023 14:21:15 -0700 Subject: [PATCH 13/16] Bug Fix, support cancel query in running state (#2351) Signed-off-by: Peng Huo --- .../spark/execution/statement/Statement.java | 10 +- .../execution/statement/StatementTest.java | 112 +++++++++++++++--- 2 files changed, 104 insertions(+), 18 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index d84c91bdb8..94c1f79511 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -62,9 +62,15 @@ public void open() { /** Cancel a statement. 
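Only a statement in WAITING or RUNNING state can be cancelled; SUCCESS, FAILED and CANCELLED are terminal states, and cancelling them throws an IllegalStateException.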
*/
   public void cancel() {
-    if (statementModel.getStatementState().equals(StatementState.RUNNING)) {
+    StatementState statementState = statementModel.getStatementState();
+
+    if (statementState.equals(StatementState.SUCCESS)
+        || statementState.equals(StatementState.FAILED)
+        || statementState.equals(StatementState.CANCELLED)) {
       String errorMsg =
-          String.format("can't cancel statement in waiting state. statement: %s.", statementId);
+          String.format(
+              "can't cancel statement in %s state. statement: %s.",
+              statementState.getState(), statementId);
       LOG.error(errorMsg);
       throw new IllegalStateException(errorMsg);
     }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index 1e33c8a6b9..29020f2496 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -8,6 +8,7 @@
 import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest;
 import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting;
 import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED;
+import static org.opensearch.sql.spark.execution.statement.StatementState.RUNNING;
 import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;
 import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement;
 import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
@@ -168,38 +169,93 @@ public void cancelFailedBecauseOfConflict() {
   }

   @Test
-  public void cancelRunningStatementFailed() {
+  public void cancelSuccessStatementFailed() {
     StatementId stId = new StatementId("statementId");
-    Statement st =
-        Statement.builder()
-            .sessionId(new SessionId("sessionId"))
-            .applicationId("appId")
-            .jobId("jobId")
-            .statementId(stId)
-            .langType(LangType.SQL)
-            .datasourceName(DS_NAME)
-            .query("query")
-            .queryId("statementId")
-            .stateStore(stateStore)
-            .build();
-    st.open();
+    Statement st = createStatement(stId);
+
+    // update to success state
+    StatementModel model = st.getStatementModel();
+    st.setStatementModel(
+        StatementModel.copyWithState(
+            st.getStatementModel(),
+            StatementState.SUCCESS,
+            model.getSeqNo(),
+            model.getPrimaryTerm()));
+
+    // cancelling a statement in a terminal state fails
+    IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel);
+    assertEquals(
+        String.format("can't cancel statement in success state. statement: %s.", stId),
+        exception.getMessage());
+  }
+
+  @Test
+  public void cancelFailedStatementFailed() {
+    StatementId stId = new StatementId("statementId");
+    Statement st = createStatement(stId);

     // update to failed state
     StatementModel model = st.getStatementModel();
     st.setStatementModel(
         StatementModel.copyWithState(
             st.getStatementModel(),
-            StatementState.RUNNING,
+            StatementState.FAILED,
             model.getSeqNo(),
             model.getPrimaryTerm()));

     // cancelling a statement in a terminal state fails
     IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel);
     assertEquals(
-        String.format("can't cancel statement in waiting state. statement: %s.", stId),
+        String.format("can't cancel statement in failed state. statement: %s.", stId),
+        exception.getMessage());
+  }
+
+  @Test
+  public void cancelCancelledStatementFailed() {
+    StatementId stId = new StatementId("statementId");
+    Statement st = createStatement(stId);
+
+    // update to cancelled state
+    StatementModel model = st.getStatementModel();
+    st.setStatementModel(
+        StatementModel.copyWithState(
+            st.getStatementModel(), CANCELLED, model.getSeqNo(), model.getPrimaryTerm()));
+
+    // cancelling a statement in a terminal state fails
+    IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel);
+    assertEquals(
+        String.format("can't cancel statement in cancelled state. statement: %s.", stId),
         exception.getMessage());
   }

+  @Test
+  public void cancelRunningStatementSuccess() {
+    Statement st =
+        Statement.builder()
+            .sessionId(new SessionId("sessionId"))
+            .applicationId("appId")
+            .jobId("jobId")
+            .statementId(new StatementId("statementId"))
+            .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
+            .query("query")
+            .queryId("statementId")
+            .stateStore(stateStore)
+            .build();
+
+    // submit statement
+    TestStatement testStatement = testStatement(st, stateStore);
+    testStatement
+        .open()
+        .assertSessionState(WAITING)
+        .assertStatementId(new StatementId("statementId"));
+
+    testStatement.run();
+
+    // cancel the running statement
+    testStatement.cancel().assertSessionState(CANCELLED);
+  }

   @Test
   public void submitStatementInRunningSession() {
     Session session =
@@ -355,9 +411,33 @@ public TestStatement cancel() {
       st.cancel();
       return this;
     }
+
+    public TestStatement run() {
+      StatementModel model =
+          updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), RUNNING);
+      st.setStatementModel(model);
+      return this;
+    }
   }

   private QueryRequest queryRequest() {
     return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1");
   }
+
+  private Statement createStatement(StatementId stId) {
+    Statement st =
+        Statement.builder()
+            .sessionId(new SessionId("sessionId"))
+            .applicationId("appId")
+            .jobId("jobId")
+            .statementId(stId)
+            .langType(LangType.SQL)
+            .datasourceName(DS_NAME)
+            .query("query")
+            .queryId("statementId")
+            .stateStore(stateStore)
+            .build();
+    st.open();
+    return st;
+  }
 }

From a2014eed8de1114d264667d1b812e8cf3b673971 Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Wed, 25 Oct 2023 09:15:39 -0700
Subject: [PATCH 14/16] Add Session limitation (#2354)

* add session creation limitation

Signed-off-by: Peng Huo

* add doc

Signed-off-by: Peng Huo

---------

Signed-off-by: Peng Huo
---
 .../sql/common/setting/Settings.java          |  3 +-
 docs/user/admin/settings.rst                  | 36 ++++++++++
 .../setting/OpenSearchSettings.java           | 14 +++++
 .../execution/session/SessionManager.java     | 16 +++++
 .../execution/statestore/StateStore.java      | 51 ++++++++++++++++
 .../query_execution_request_mapping.yml       |  2 +
 ...AsyncQueryExecutorServiceImplSpecTest.java | 58 +++++++++++++++++++
 .../execution/session/SessionManagerTest.java |  1 +
 8 files changed, 180 insertions(+), 1 deletion(-)

diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
index 89d046b3d9..ae1950d81c 100644
--- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
+++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java
@@ -39,7 +39,8 @@ public enum Key {
     METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"),
     SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"),
     CLUSTER_NAME("cluster.name"),
SPARK_EXECUTION_SESSION_ENABLED("plugins.query.executionengine.spark.session.enabled"); + SPARK_EXECUTION_SESSION_ENABLED("plugins.query.executionengine.spark.session.enabled"), + SPARK_EXECUTION_SESSION_LIMIT("plugins.query.executionengine.spark.session.limit"); @Getter private final String keyValue; diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst index cd56e76491..686116636a 100644 --- a/docs/user/admin/settings.rst +++ b/docs/user/admin/settings.rst @@ -347,3 +347,39 @@ SQL query:: } } +plugins.query.executionengine.spark.session.limit +=================================================== + +Description +----------- + +Each datasource can have maximum 100 sessions running in parallel by default. You can increase limit by this setting. + +1. The default value is 100. +2. This setting is node scope. +3. This setting can be updated dynamically. + +You can update the setting with a new value like this. + +SQL query:: + + sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \ + ... -d '{"transient":{"plugins.query.executionengine.spark.session.limit":200}}' + { + "acknowledged": true, + "persistent": {}, + "transient": { + "plugins": { + "query": { + "executionengine": { + "spark": { + "session": { + "limit": "200" + } + } + } + } + } + } + } + diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index ecb35afafa..f80b576fe0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -142,6 +142,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting SPARK_EXECUTION_SESSION_LIMIT_SETTING = + Setting.intSetting( + Key.SPARK_EXECUTION_SESSION_LIMIT.getKeyValue(), + 100, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. 
*/
  @SuppressWarnings("unchecked")
  public OpenSearchSettings(ClusterSettings clusterSettings) {
@@ -218,6 +225,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) {
        Key.SPARK_EXECUTION_SESSION_ENABLED,
        SPARK_EXECUTION_SESSION_ENABLED_SETTING,
        new Updater(Key.SPARK_EXECUTION_SESSION_ENABLED));
+    register(
+        settingBuilder,
+        clusterSettings,
+        Key.SPARK_EXECUTION_SESSION_LIMIT,
+        SPARK_EXECUTION_SESSION_LIMIT_SETTING,
+        new Updater(Key.SPARK_EXECUTION_SESSION_LIMIT));
     registerNonDynamicSettings(
         settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING);
     defaultSettings = settingBuilder.build();
@@ -284,6 +297,7 @@ public static List<Setting<?>> pluginSettings() {
         .add(DATASOURCE_URI_HOSTS_DENY_LIST)
         .add(SPARK_EXECUTION_ENGINE_CONFIG)
         .add(SPARK_EXECUTION_SESSION_ENABLED_SETTING)
+        .add(SPARK_EXECUTION_SESSION_LIMIT_SETTING)
         .build();
   }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
index c0f7bbcde8..81b9fdaee0 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
@@ -6,8 +6,11 @@ package org.opensearch.sql.spark.execution.session;

 import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED;
+import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_LIMIT;
 import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.activeSessionsCount;

+import java.util.Locale;
 import java.util.Optional;
 import lombok.RequiredArgsConstructor;
 import org.opensearch.sql.common.setting.Settings;
@@ -26,6 +29,15 @@ public class SessionManager {
   private final Settings settings;

   public Session createSession(CreateSessionRequest request) {
+    int sessionMaxLimit = sessionMaxLimit();
+    if (activeSessionsCount(stateStore, request.getDatasourceName()).get() >= sessionMaxLimit) {
+      String errorMsg =
+          String.format(
+              Locale.ROOT,
+              "The maximum number of active sessions that can be supported is %d",
+              sessionMaxLimit);
+      throw new IllegalArgumentException(errorMsg);
+    }
     InteractiveSession session =
         InteractiveSession.builder()
             .sessionId(newSessionId(request.getDatasourceName()))
@@ -55,4 +67,8 @@ public Optional<Session> getSession(SessionId sid) {
   public boolean isEnabled() {
     return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED);
   }
+
+  public int sessionMaxLimit() {
+    return settings.getSettingValue(SPARK_EXECUTION_SESSION_LIMIT);
+  }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
index 6546d303fb..e6bad9fc26 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
@@ -14,6 +14,7 @@
 import java.util.Optional;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import lombok.RequiredArgsConstructor;
 import org.apache.commons.io.IOUtils;
 import org.apache.logging.log4j.LogManager;
@@ -25,6 +26,8 @@
 import org.opensearch.action.get.GetResponse;
 import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
 import org.opensearch.action.support.WriteRequest;
 import org.opensearch.action.update.UpdateRequest;
 import org.opensearch.action.update.UpdateResponse;
@@ -38,9 +41,13 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.execution.session.SessionModel;
 import org.opensearch.sql.spark.execution.session.SessionState;
+import org.opensearch.sql.spark.execution.session.SessionType;
 import org.opensearch.sql.spark.execution.statement.StatementModel;
 import org.opensearch.sql.spark.execution.statement.StatementState;
@@ -182,6 +189,35 @@ private void createIndex(String indexName) {
     }
   }

+  private long count(String indexName, QueryBuilder query) {
+    if (!this.clusterService.state().routingTable().hasIndex(indexName)) {
+      return 0;
+    }
+    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+    searchSourceBuilder.query(query);
+    searchSourceBuilder.size(0);
+
+    // https://github.com/opensearch-project/sql/issues/1801.
+    SearchRequest searchRequest =
+        new SearchRequest()
+            .indices(indexName)
+            .preference("_primary_first")
+            .source(searchSourceBuilder);
+
+    ActionFuture<SearchResponse> searchResponseActionFuture;
+    try (ThreadContext.StoredContext ignored =
+        client.threadPool().getThreadContext().stashContext()) {
+      searchResponseActionFuture = client.search(searchRequest);
+    }
+    SearchResponse searchResponse = searchResponseActionFuture.actionGet();
+    if (searchResponse.status().getStatus() != 200) {
+      throw new RuntimeException(
+          "Count query failed with status : " + searchResponse.status());
+    } else {
+      return searchResponse.getHits().getTotalHits().value;
+    }
+  }
+
   private String loadConfigFromResource(String fileName) throws IOException {
     InputStream fileStream = StateStore.class.getClassLoader().getResourceAsStream(fileName);
     return IOUtils.toString(fileStream, StandardCharsets.UTF_8);
@@ -253,4 +289,19 @@ public static Function<String, Optional<AsyncQueryJobMetadata>> getJobMetaData(
         AsyncQueryJobMetadata::fromXContent,
         DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }
+
+  public static Supplier<Long> activeSessionsCount(StateStore stateStore, String datasourceName) {
+    return () ->
+        stateStore.count(
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName),
+            QueryBuilders.boolQuery()
+                .must(QueryBuilders.termQuery(SessionModel.TYPE, SessionModel.SESSION_DOC_TYPE))
+                .must(
+                    QueryBuilders.termQuery(
+                        SessionModel.SESSION_TYPE, SessionType.INTERACTIVE.getSessionType()))
+                .must(QueryBuilders.termQuery(SessionModel.DATASOURCE_NAME, datasourceName))
+                .must(
+                    QueryBuilders.termQuery(
+                        SessionModel.SESSION_STATE, SessionState.RUNNING.getSessionState())));
+  }
 }
diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml
index fbe90a1cba..682534d338 100644
--- a/spark/src/main/resources/query_execution_request_mapping.yml
+++ b/spark/src/main/resources/query_execution_request_mapping.yml
@@ -40,3 +40,5 @@ properties:
       format: strict_date_time||epoch_millis
   queryId:
     type: keyword
+  excludeJobIds:
+    type: keyword
diff --git
a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 19edd53eae..f65049a7d9 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.asyncquery; import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_ENABLED_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_LIMIT_SETTING; import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; @@ -16,7 +17,9 @@ import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; @@ -61,6 +64,8 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -129,6 +134,13 @@ public void clean() { .setTransientSettings( Settings.builder().putNull(SPARK_EXECUTION_SESSION_ENABLED_SETTING.getKey()).build()) .get(); + client + .admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings( + Settings.builder().putNull(SPARK_EXECUTION_SESSION_LIMIT_SETTING.getKey()).build()) + .get(); } @Test @@ -378,6 +390,35 @@ public void withSessionCreateAsyncQueryFailed() { assertEquals("mock error", asyncQueryResults.getError()); } + @Test + public void createSessionMoreThanLimitFailed() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + // only allow one session in domain. + setSessionLimit(1); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + setSessionState(first.getSessionId(), SessionState.RUNNING); + + // 2. create async query without session. 
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                asyncQueryExecutorService.createAsyncQuery(
+                    new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)));
+    assertEquals(
+        "The maximum number of active sessions that can be supported is 1",
+        exception.getMessage());
+  }
+
   private DataSourceServiceImpl createDataSourceService() {
     String masterKey = "a57d991d9b573f75b9bba1df";
     DataSourceMetadataStorage dataSourceMetadataStorage =
@@ -470,6 +511,16 @@ public void enableSession(boolean enabled) {
         .get();
   }
+
+  public void setSessionLimit(long limit) {
+    client
+        .admin()
+        .cluster()
+        .prepareUpdateSettings()
+        .setTransientSettings(
+            Settings.builder().put(SPARK_EXECUTION_SESSION_LIMIT_SETTING.getKey(), limit).build())
+        .get();
+  }
+
   int search(QueryBuilder query) {
     SearchRequest searchRequest = new SearchRequest();
     searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(DATASOURCE));
@@ -480,4 +531,11 @@ int search(QueryBuilder query) {
     return searchResponse.getHits().getHits().length;
   }
+
+  void setSessionState(String sessionId, SessionState sessionState) {
+    Optional<SessionModel> model = getSession(stateStore, DATASOURCE).apply(sessionId);
+    SessionModel updated =
+        updateSessionState(stateStore, DATASOURCE).apply(model.get(), sessionState);
+    assertEquals(sessionState, updated.getSessionState());
+  }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
index 4374bd4f11..3546a874d9 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
@@ -33,6 +33,7 @@ public void sessionEnable() {
   public static Settings sessionSetting(boolean enabled) {
     Map<Settings.Key, Object> settings = new HashMap<>();
     settings.put(Settings.Key.SPARK_EXECUTION_SESSION_ENABLED, enabled);
+    settings.put(Settings.Key.SPARK_EXECUTION_SESSION_LIMIT, 100);
     return settings(settings);
   }

From 886c2fcc87c461304a9457653d25255b985c9837 Mon Sep 17 00:00:00 2001
From: Vamsi Manohar
Date: Wed, 25 Oct 2023 09:24:24 -0700
Subject: [PATCH 15/16] Handle Describe, Refresh and Show Queries Properly (#2357)

Signed-off-by: Vamsi Manohar
---
 .../src/main/antlr/FlintSparkSqlExtensions.g4 |   5 +
 .../dispatcher/SparkQueryDispatcher.java      |  84 +++++--
 .../model/IndexQueryActionType.java           |  15 ++
 ...dexDetails.java => IndexQueryDetails.java} |  65 ++---
 .../execution/session/InteractiveSession.java |   3 +
 .../spark/flint/FlintIndexMetadataReader.java |   6 +-
 .../flint/FlintIndexMetadataReaderImpl.java   |   6 +-
 .../sql/spark/utils/SQLQueryUtils.java        | 113 +++++++--
 .../dispatcher/SparkQueryDispatcherTest.java  | 234 +++++++++++++++---
 .../session/InteractiveSessionTest.java       |   3 +-
 .../FlintIndexMetadataReaderImplTest.java     |  15 +-
 ...lsTest.java => IndexQueryDetailsTest.java} |   9 +-
 .../sql/spark/utils/SQLQueryUtilsTest.java    | 123 +++++++--
 13 files changed, 519 insertions(+), 162 deletions(-)
 create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java
 rename spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/{IndexDetails.java => IndexQueryDetails.java} (55%)
 rename spark/src/test/java/org/opensearch/sql/spark/flint/{IndexDetailsTest.java => IndexQueryDetailsTest.java} (71%)

diff --git a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 index
c4af2779d1..f48c276e44 100644 --- a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 +++ b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 @@ -79,6 +79,7 @@ dropCoveringIndexStatement materializedViewStatement : createMaterializedViewStatement + | refreshMaterializedViewStatement | showMaterializedViewStatement | describeMaterializedViewStatement | dropMaterializedViewStatement @@ -90,6 +91,10 @@ createMaterializedViewStatement (WITH LEFT_PAREN propertyList RIGHT_PAREN)? ; +refreshMaterializedViewStatement + : REFRESH MATERIALIZED VIEW mvName=multipartIdentifier + ; + showMaterializedViewStatement : SHOW MATERIALIZED (VIEW | VIEWS) IN catalogDb=multipartIdentifier ; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index ff7ccf8c08..6ec67709b8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -36,10 +36,7 @@ import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; -import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; -import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.dispatcher.model.*; import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -56,11 +53,10 @@ public class SparkQueryDispatcher { private static final Logger LOG = LogManager.getLogger(); - public static final String INDEX_TAG_KEY = "index"; public static final String DATASOURCE_TAG_KEY = "datasource"; public static final String CLUSTER_NAME_TAG_KEY = "cluster"; - public static final String JOB_TYPE_TAG_KEY = "job_type"; + public static final String JOB_TYPE_TAG_KEY = "type"; private EMRServerlessClient emrServerlessClient; @@ -107,15 +103,18 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { } private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) { - if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) { - IndexDetails indexDetails = + if (SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { + IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); - fillMissingDetails(dispatchQueryRequest, indexDetails); + fillMissingDetails(dispatchQueryRequest, indexQueryDetails); - if (indexDetails.isDropIndex()) { - return handleDropIndexQuery(dispatchQueryRequest, indexDetails); + // TODO: refactor this code properly. 
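+ // DROP queries cancel any auto-refresh job and delete the Flint index, CREATE queries
+ // may start a long-running streaming job when auto_refresh is enabled, and the remaining
+ // Flint statements (DESCRIBE, SHOW, REFRESH) are dispatched as regular batch queries.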
+ if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) { + return handleDropIndexQuery(dispatchQueryRequest, indexQueryDetails); + } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) { + return handleStreamingQueries(dispatchQueryRequest, indexQueryDetails); } else { - return handleIndexQuery(dispatchQueryRequest, indexDetails); + return handleFlintNonStreamingQueries(dispatchQueryRequest, indexQueryDetails); } } else { return handleNonIndexQuery(dispatchQueryRequest); @@ -127,24 +126,59 @@ private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryR // Spark Assumes the datasource to be catalog. // This is required to handle drop index case properly when datasource name is not provided. private static void fillMissingDetails( - DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { - if (indexDetails.getFullyQualifiedTableName() != null - && indexDetails.getFullyQualifiedTableName().getDatasourceName() == null) { - indexDetails + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { + if (indexQueryDetails.getFullyQualifiedTableName() != null + && indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) { + indexQueryDetails .getFullyQualifiedTableName() .setDatasourceName(dispatchQueryRequest.getDatasource()); } } - private DispatchQueryResponse handleIndexQuery( - DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { + private DispatchQueryResponse handleStreamingQueries( + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { + DataSourceMetadata dataSourceMetadata = + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); + String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; + Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); + tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName()); + if (indexQueryDetails.isAutoRefresh()) { + tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); + } + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource( + dataSourceService.getRawDataSourceMetadata( + dispatchQueryRequest.getDatasource())) + .structuredStreaming(indexQueryDetails.isAutoRefresh()) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + indexQueryDetails.isAutoRefresh(), + dataSourceMetadata.getResultIndex()); + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse( + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + jobId, + false, + dataSourceMetadata.getResultIndex(), + null); + } + + private DispatchQueryResponse handleFlintNonStreamingQueries( + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - tags.put(INDEX_TAG_KEY, indexDetails.openSearchIndexName()); - 
tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), @@ -155,12 +189,11 @@ private DispatchQueryResponse handleIndexQuery( .dataSource( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) - .structuredStreaming(indexDetails.isAutoRefresh()) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), tags, - indexDetails.isAutoRefresh(), + indexQueryDetails.isAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); return new DispatchQueryResponse( @@ -242,11 +275,12 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ } private DispatchQueryResponse handleDropIndexQuery( - DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata(indexDetails); + FlintIndexMetadata indexMetadata = + flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails); // if index is created without auto refresh. there is no job to cancel. String status = JobRunState.FAILED.toString(); try { @@ -255,7 +289,7 @@ private DispatchQueryResponse handleDropIndexQuery( dispatchQueryRequest.getApplicationId(), indexMetadata.getJobId()); } } finally { - String indexName = indexDetails.openSearchIndexName(); + String indexName = indexQueryDetails.openSearchIndexName(); try { AcknowledgedResponse response = client.admin().indices().delete(new DeleteIndexRequest().indices(indexName)).get(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java new file mode 100644 index 0000000000..2c96511d2a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryActionType.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +/** Enum for Index Action in the given query.* */ +public enum IndexQueryActionType { + CREATE, + REFRESH, + DESCRIBE, + SHOW, + DROP +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryDetails.java similarity index 55% rename from spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java rename to spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryDetails.java index 42e2905e67..5b4326a10e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexQueryDetails.java @@ -5,7 +5,6 @@ package org.opensearch.sql.spark.dispatcher.model; -import com.google.common.base.Preconditions; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.commons.lang3.StringUtils; @@ -14,7 +13,7 @@ /** Index details in an async query. 
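Captures the parsed action type, the Flint index type, and the index, table, or materialized view name used to derive the physical OpenSearch index name.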
*/ @Getter @EqualsAndHashCode -public class IndexDetails { +public class IndexQueryDetails { public static final String STRIP_CHARS = "`"; @@ -22,75 +21,59 @@ public class IndexDetails { private FullyQualifiedTableName fullyQualifiedTableName; // by default, auto_refresh = false; private boolean autoRefresh; - private boolean isDropIndex; + private IndexQueryActionType indexQueryActionType; // materialized view special case where // table name and mv name are combined. private String mvName; private FlintIndexType indexType; - private IndexDetails() {} + private IndexQueryDetails() {} - public static IndexDetailsBuilder builder() { - return new IndexDetailsBuilder(); + public static IndexQueryDetailsBuilder builder() { + return new IndexQueryDetailsBuilder(); } // Builder class - public static class IndexDetailsBuilder { - private final IndexDetails indexDetails; + public static class IndexQueryDetailsBuilder { + private final IndexQueryDetails indexQueryDetails; - public IndexDetailsBuilder() { - indexDetails = new IndexDetails(); + public IndexQueryDetailsBuilder() { + indexQueryDetails = new IndexQueryDetails(); } - public IndexDetailsBuilder indexName(String indexName) { - indexDetails.indexName = indexName; + public IndexQueryDetailsBuilder indexName(String indexName) { + indexQueryDetails.indexName = indexName; return this; } - public IndexDetailsBuilder fullyQualifiedTableName(FullyQualifiedTableName tableName) { - indexDetails.fullyQualifiedTableName = tableName; + public IndexQueryDetailsBuilder fullyQualifiedTableName(FullyQualifiedTableName tableName) { + indexQueryDetails.fullyQualifiedTableName = tableName; return this; } - public IndexDetailsBuilder autoRefresh(Boolean autoRefresh) { - indexDetails.autoRefresh = autoRefresh; + public IndexQueryDetailsBuilder autoRefresh(Boolean autoRefresh) { + indexQueryDetails.autoRefresh = autoRefresh; return this; } - public IndexDetailsBuilder isDropIndex(boolean isDropIndex) { - indexDetails.isDropIndex = isDropIndex; + public IndexQueryDetailsBuilder indexQueryActionType( + IndexQueryActionType indexQueryActionType) { + indexQueryDetails.indexQueryActionType = indexQueryActionType; return this; } - public IndexDetailsBuilder mvName(String mvName) { - indexDetails.mvName = mvName; + public IndexQueryDetailsBuilder mvName(String mvName) { + indexQueryDetails.mvName = mvName; return this; } - public IndexDetailsBuilder indexType(FlintIndexType indexType) { - indexDetails.indexType = indexType; + public IndexQueryDetailsBuilder indexType(FlintIndexType indexType) { + indexQueryDetails.indexType = indexType; return this; } - public IndexDetails build() { - Preconditions.checkNotNull(indexDetails.indexType, "Index Type can't be null"); - switch (indexDetails.indexType) { - case COVERING: - Preconditions.checkNotNull( - indexDetails.indexName, "IndexName can't be null for Covering Index."); - Preconditions.checkNotNull( - indexDetails.fullyQualifiedTableName, "TableName can't be null for Covering Index."); - break; - case SKIPPING: - Preconditions.checkNotNull( - indexDetails.fullyQualifiedTableName, "TableName can't be null for Skipping Index."); - break; - case MATERIALIZED_VIEW: - Preconditions.checkNotNull(indexDetails.mvName, "Materialized view name can't be null"); - break; - } - - return indexDetails; + public IndexQueryDetails build() { + return indexQueryDetails; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java 
b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index a2e7cfe6ee..956275b04a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -34,6 +34,8 @@ public class InteractiveSession implements Session { private static final Logger LOG = LogManager.getLogger(); + public static final String SESSION_ID_TAG_KEY = "sid"; + private final SessionId sessionId; private final StateStore stateStore; private final EMRServerlessClient serverlessClient; @@ -46,6 +48,7 @@ public void open(CreateSessionRequest createSessionRequest) { createSessionRequest .getSparkSubmitParametersBuilder() .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); + createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId()); String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java index e4a5e92035..d4a8e7ddbf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java @@ -1,6 +1,6 @@ package org.opensearch.sql.spark.flint; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; /** Interface for FlintIndexMetadataReader */ public interface FlintIndexMetadataReader { @@ -8,8 +8,8 @@ public interface FlintIndexMetadataReader { /** * Given Index details, get the streaming job Id. * - * @param indexDetails indexDetails. + * @param indexQueryDetails indexDetails. * @return FlintIndexMetadata. 
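* @throws IllegalArgumentException if the provided index does not exist.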
*/ - FlintIndexMetadata getFlintIndexMetadata(IndexDetails indexDetails); + FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexQueryDetails); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java index 5f712e65cd..a16d0b9138 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java @@ -5,7 +5,7 @@ import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; /** Implementation of {@link FlintIndexMetadataReader} */ @AllArgsConstructor @@ -14,8 +14,8 @@ public class FlintIndexMetadataReaderImpl implements FlintIndexMetadataReader { private final Client client; @Override - public FlintIndexMetadata getFlintIndexMetadata(IndexDetails indexDetails) { - String indexName = indexDetails.openSearchIndexName(); + public FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexQueryDetails) { + String indexName = indexQueryDetails.openSearchIndexName(); GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings(indexName).get(); try { diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 4816f1c2cd..c1f3f02576 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -20,7 +20,8 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.flint.FlintIndexType; /** @@ -42,7 +43,7 @@ public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQ return sparkSqlTableNameVisitor.getFullyQualifiedTableName(); } - public static IndexDetails extractIndexDetails(String sqlQuery) { + public static IndexQueryDetails extractIndexDetails(String sqlQuery) { FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = new FlintSparkSqlExtensionsParser( new CommonTokenStream( @@ -52,10 +53,10 @@ public static IndexDetails extractIndexDetails(String sqlQuery) { flintSparkSqlExtensionsParser.statement(); FlintSQLIndexDetailsVisitor flintSQLIndexDetailsVisitor = new FlintSQLIndexDetailsVisitor(); statementContext.accept(flintSQLIndexDetailsVisitor); - return flintSQLIndexDetailsVisitor.getIndexDetailsBuilder().build(); + return flintSQLIndexDetailsVisitor.getIndexQueryDetailsBuilder().build(); } - public static boolean isIndexQuery(String sqlQuery) { + public static boolean isFlintExtensionQuery(String sqlQuery) { FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = new FlintSparkSqlExtensionsParser( new CommonTokenStream( @@ -117,29 +118,29 @@ public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { public static class 
FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor { - @Getter private final IndexDetails.IndexDetailsBuilder indexDetailsBuilder; + @Getter private final IndexQueryDetails.IndexQueryDetailsBuilder indexQueryDetailsBuilder; public FlintSQLIndexDetailsVisitor() { - this.indexDetailsBuilder = new IndexDetails.IndexDetailsBuilder(); + this.indexQueryDetailsBuilder = new IndexQueryDetails.IndexQueryDetailsBuilder(); } @Override public Void visitIndexName(FlintSparkSqlExtensionsParser.IndexNameContext ctx) { - indexDetailsBuilder.indexName(ctx.getText()); + indexQueryDetailsBuilder.indexName(ctx.getText()); return super.visitIndexName(ctx); } @Override public Void visitTableName(FlintSparkSqlExtensionsParser.TableNameContext ctx) { - indexDetailsBuilder.fullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); + indexQueryDetailsBuilder.fullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); return super.visitTableName(ctx); } @Override public Void visitCreateSkippingIndexStatement( FlintSparkSqlExtensionsParser.CreateSkippingIndexStatementContext ctx) { - indexDetailsBuilder.isDropIndex(false); - indexDetailsBuilder.indexType(FlintIndexType.SKIPPING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.CREATE); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); visitPropertyList(ctx.propertyList()); return super.visitCreateSkippingIndexStatement(ctx); } @@ -147,8 +148,8 @@ public Void visitCreateSkippingIndexStatement( @Override public Void visitCreateCoveringIndexStatement( FlintSparkSqlExtensionsParser.CreateCoveringIndexStatementContext ctx) { - indexDetailsBuilder.isDropIndex(false); - indexDetailsBuilder.indexType(FlintIndexType.COVERING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.CREATE); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); visitPropertyList(ctx.propertyList()); return super.visitCreateCoveringIndexStatement(ctx); } @@ -156,9 +157,9 @@ public Void visitCreateCoveringIndexStatement( @Override public Void visitCreateMaterializedViewStatement( FlintSparkSqlExtensionsParser.CreateMaterializedViewStatementContext ctx) { - indexDetailsBuilder.isDropIndex(false); - indexDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); - indexDetailsBuilder.mvName(ctx.mvName.getText()); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.CREATE); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); visitPropertyList(ctx.propertyList()); return super.visitCreateMaterializedViewStatement(ctx); } @@ -166,28 +167,94 @@ public Void visitCreateMaterializedViewStatement( @Override public Void visitDropCoveringIndexStatement( FlintSparkSqlExtensionsParser.DropCoveringIndexStatementContext ctx) { - indexDetailsBuilder.isDropIndex(true); - indexDetailsBuilder.indexType(FlintIndexType.COVERING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DROP); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); return super.visitDropCoveringIndexStatement(ctx); } @Override public Void visitDropSkippingIndexStatement( FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext ctx) { - indexDetailsBuilder.isDropIndex(true); - indexDetailsBuilder.indexType(FlintIndexType.SKIPPING); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DROP); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); return super.visitDropSkippingIndexStatement(ctx); } 
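// For example (illustrative, based on the visitor methods above and the SQLQueryUtilsTest
// cases later in this patch): SQLQueryUtils.extractIndexDetails(
//     "DROP SKIPPING INDEX ON my_glue.default.http_logs")
// yields an IndexQueryDetails whose getIndexQueryActionType() is IndexQueryActionType.DROP
// and whose getIndexType() is FlintIndexType.SKIPPING.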
@Override public Void visitDropMaterializedViewStatement( FlintSparkSqlExtensionsParser.DropMaterializedViewStatementContext ctx) { - indexDetailsBuilder.isDropIndex(true); - indexDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); - indexDetailsBuilder.mvName(ctx.mvName.getText()); + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DROP); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); return super.visitDropMaterializedViewStatement(ctx); } + @Override + public Void visitDescribeCoveringIndexStatement( + FlintSparkSqlExtensionsParser.DescribeCoveringIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DESCRIBE); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); + return super.visitDescribeCoveringIndexStatement(ctx); + } + + @Override + public Void visitDescribeSkippingIndexStatement( + FlintSparkSqlExtensionsParser.DescribeSkippingIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DESCRIBE); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); + return super.visitDescribeSkippingIndexStatement(ctx); + } + + @Override + public Void visitDescribeMaterializedViewStatement( + FlintSparkSqlExtensionsParser.DescribeMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.DESCRIBE); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); + return super.visitDescribeMaterializedViewStatement(ctx); + } + + @Override + public Void visitShowCoveringIndexStatement( + FlintSparkSqlExtensionsParser.ShowCoveringIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.SHOW); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); + return super.visitShowCoveringIndexStatement(ctx); + } + + @Override + public Void visitShowMaterializedViewStatement( + FlintSparkSqlExtensionsParser.ShowMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.SHOW); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + return super.visitShowMaterializedViewStatement(ctx); + } + + @Override + public Void visitRefreshCoveringIndexStatement( + FlintSparkSqlExtensionsParser.RefreshCoveringIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.REFRESH); + indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING); + return super.visitRefreshCoveringIndexStatement(ctx); + } + + @Override + public Void visitRefreshSkippingIndexStatement( + FlintSparkSqlExtensionsParser.RefreshSkippingIndexStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.REFRESH); + indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING); + return super.visitRefreshSkippingIndexStatement(ctx); + } + + @Override + public Void visitRefreshMaterializedViewStatement( + FlintSparkSqlExtensionsParser.RefreshMaterializedViewStatementContext ctx) { + indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.REFRESH); + indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW); + indexQueryDetailsBuilder.mvName(ctx.mvName.getText()); + return super.visitRefreshMaterializedViewStatement(ctx); + } + @Override public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext ctx) { if (ctx != 
null) { @@ -199,7 +266,7 @@ public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext // https://github.com/apache/spark/blob/v3.5.0/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala#L35 to unescape string literal if (propertyKey(property.key).toLowerCase(Locale.ROOT).contains("auto_refresh")) { if (propertyValue(property.value).toLowerCase(Locale.ROOT).contains("true")) { - indexDetailsBuilder.autoRefresh(true); + indexQueryDetailsBuilder.autoRefresh(true); } } }); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index a69c6e2b1a..fc8623d51a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -63,11 +63,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; -import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; -import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; -import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.dispatcher.model.*; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; @@ -126,7 +122,7 @@ void testDispatchSelectQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.BATCH.getText()); + tags.put("type", JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -181,7 +177,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.BATCH.getText()); + tags.put("type", JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -237,7 +233,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.BATCH.getText()); + tags.put("type", JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -372,7 +368,7 @@ void testDispatchIndexQuery() { tags.put("datasource", "my_glue"); tags.put("index", "flint_my_glue_default_http_logs_elb_and_requesturi_index"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.STREAMING.getText()); + tags.put("type", JobType.STREAMING.getText()); String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; @@ -430,7 +426,7 @@ void testDispatchWithPPLQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - 
tags.put("job_type", JobType.BATCH.getText()); + tags.put("type", JobType.BATCH.getText()); String query = "source = my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -485,7 +481,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.BATCH.getText()); + tags.put("type", JobType.BATCH.getText()); String query = "show tables"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -541,7 +537,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { tags.put("datasource", "my_glue"); tags.put("index", "flint_my_glue_default_http_logs_elb_and_requesturi_index"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.STREAMING.getText()); + tags.put("type", JobType.STREAMING.getText()); String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; @@ -600,7 +596,7 @@ void testDispatchMaterializedViewQuery() { tags.put("datasource", "my_glue"); tags.put("index", "flint_mv_1"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("job_type", JobType.STREAMING.getText()); + tags.put("type", JobType.STREAMING.getText()); String query = "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + " (auto_refresh = true)"; @@ -653,6 +649,168 @@ void testDispatchMaterializedViewQuery() { verifyNoInteractions(flintIndexMetadataReader); } + @Test + void testDispatchShowMVQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); + String query = "SHOW MATERIALIZED VIEW IN mys3.default"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + + @Test + void testRefreshIndexQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); + String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", 
+ new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + + @Test + void testDispatchDescribeIndexQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); + String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + any()))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); + } + @Test void testDispatchWithWrongURI() { when(dataSourceService.getRawDataSourceMetadata("my_glue")) @@ -903,15 +1061,15 @@ void testGetQueryResponseOfDropIndex() { @Test void testDropIndexQuery() throws ExecutionException, InterruptedException { String query = "DROP INDEX size_year ON my_glue.default.http_logs"; - IndexDetails indexDetails = - IndexDetails.builder() + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() .indexName("size_year") .fullyQualifiedTableName(new 
FullyQualifiedTableName("my_glue.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.COVERING) .build(); - when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); // auto_refresh == true @@ -940,7 +1098,7 @@ void testDropIndexQuery() throws ExecutionException, InterruptedException { TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -950,14 +1108,14 @@ void testDropIndexQuery() throws ExecutionException, InterruptedException { @Test void testDropSkippingIndexQuery() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - IndexDetails indexDetails = - IndexDetails.builder() + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); - when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); when(flintIndexMetadata.isAutoRefresh()).thenReturn(true); @@ -984,7 +1142,7 @@ void testDropSkippingIndexQuery() throws ExecutionException, InterruptedExceptio TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -995,14 +1153,14 @@ void testDropSkippingIndexQuery() throws ExecutionException, InterruptedExceptio void testDropSkippingIndexQueryAutoRefreshFalse() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - IndexDetails indexDetails = - IndexDetails.builder() + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); - when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); 
when(flintIndexMetadata.isAutoRefresh()).thenReturn(false); @@ -1023,7 +1181,7 @@ void testDropSkippingIndexQueryAutoRefreshFalse() TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(0)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); @@ -1034,14 +1192,14 @@ void testDropSkippingIndexQueryAutoRefreshFalse() void testDropSkippingIndexQueryDeleteIndexException() throws ExecutionException, InterruptedException { String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; - IndexDetails indexDetails = - IndexDetails.builder() + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() .fullyQualifiedTableName(new FullyQualifiedTableName("my_glue.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); - when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.isAutoRefresh()).thenReturn(false); @@ -1063,7 +1221,7 @@ void testDropSkippingIndexQueryDeleteIndexException() TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(0)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.FAILED.toString(), dropIndexResult.getStatus()); @@ -1076,14 +1234,14 @@ void testDropSkippingIndexQueryDeleteIndexException() @Test void testDropMVQuery() throws ExecutionException, InterruptedException { String query = "DROP MATERIALIZED VIEW mv_1"; - IndexDetails indexDetails = - IndexDetails.builder() + IndexQueryDetails indexQueryDetails = + IndexQueryDetails.builder() .mvName("mv_1") - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .fullyQualifiedTableName(null) .indexType(FlintIndexType.MATERIALIZED_VIEW) .build(); - when(flintIndexMetadataReader.getFlintIndexMetadata(indexDetails)) + when(flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails)) .thenReturn(flintIndexMetadata); when(flintIndexMetadata.getJobId()).thenReturn(EMR_JOB_ID); // auto_refresh == true @@ -1112,7 +1270,7 @@ void testDropMVQuery() throws ExecutionException, InterruptedException { TEST_CLUSTER_NAME)); verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); - verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexDetails); + verify(flintIndexMetadataReader, times(1)).getFlintIndexMetadata(indexQueryDetails); SparkQueryDispatcher.DropIndexResult dropIndexResult = 
SparkQueryDispatcher.DropIndexResult.fromJobId(dispatchQueryResponse.getJobId()); Assertions.assertEquals(JobRunState.SUCCESS.toString(), dropIndexResult.getStatus()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 06a8d8c73c..14ccaf7708 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -13,7 +13,6 @@ import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; -import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -194,7 +193,7 @@ public static CreateSessionRequest createSessionRequest() { "appId", "arn", SparkSubmitParameters.Builder.builder(), - ImmutableMap.of(), + new HashMap<>(), "resultIndex", DS_NAME); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java index 3cc40e0df5..4d809c31dc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java @@ -25,7 +25,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; @ExtendWith(MockitoExtension.class) public class FlintIndexMetadataReaderImplTest { @@ -44,10 +45,10 @@ void testGetJobIdFromFlintSkippingIndexMetadata() { FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client); FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata( - IndexDetails.builder() + IndexQueryDetails.builder() .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build()); Assertions.assertEquals("00fdmvv9hp8u0o0q", indexMetadata.getJobId()); @@ -64,11 +65,11 @@ void testGetJobIdFromFlintCoveringIndexMetadata() { FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client); FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata( - IndexDetails.builder() + IndexQueryDetails.builder() .indexName("cv1") .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.COVERING) .build()); Assertions.assertEquals("00fdmvv9hp8u0o0q", indexMetadata.getJobId()); @@ -87,12 +88,12 @@ void testGetJobIDWithNPEException() { IllegalArgumentException.class, () -> flintIndexMetadataReader.getFlintIndexMetadata( - IndexDetails.builder() + IndexQueryDetails.builder() .indexName("cv1") .fullyQualifiedTableName( new FullyQualifiedTableName("mys3.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + 
.indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.COVERING) .build())); Assertions.assertEquals("Provided Index doesn't exist", illegalArgumentException.getMessage()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexQueryDetailsTest.java similarity index 71% rename from spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java rename to spark/src/test/java/org/opensearch/sql/spark/flint/IndexQueryDetailsTest.java index cf6b5f8f2b..e725ddc21e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/IndexDetailsTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/IndexQueryDetailsTest.java @@ -9,18 +9,19 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; -public class IndexDetailsTest { +public class IndexQueryDetailsTest { @Test public void skippingIndexName() { assertEquals( "flint_mys3_default_http_logs_skipping_index", - IndexDetails.builder() + IndexQueryDetails.builder() .indexName("invalid") .fullyQualifiedTableName(new FullyQualifiedTableName("mys3.default.http_logs")) .autoRefresh(false) - .isDropIndex(true) + .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build() .openSearchIndexName()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index 01759c2bdd..c86d7656d6 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -15,7 +15,9 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; -import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; +import org.opensearch.sql.spark.flint.FlintIndexType; @ExtendWith(MockitoExtension.class) public class SQLQueryUtilsTest { @@ -25,13 +27,13 @@ void testExtractionOfTableNameFromSQLQueries() { String sqlQuery = "select * from my_glue.default.http_logs"; FullyQualifiedTableName fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); sqlQuery = "select * from my_glue.db.http_logs"; - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("db", fullyQualifiedTableName.getSchemaName()); @@ -39,28 +41,28 @@ void testExtractionOfTableNameFromSQLQueries() 
{ sqlQuery = "select * from my_glue.http_logs"; fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertEquals("my_glue", fullyQualifiedTableName.getSchemaName()); Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); sqlQuery = "select * from http_logs"; fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); sqlQuery = "DROP TABLE myS3.default.alb_logs"; fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); sqlQuery = "DESCRIBE TABLE myS3.default.alb_logs"; fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); @@ -73,7 +75,7 @@ void testExtractionOfTableNameFromSQLQueries() { + "STORED AS file_format\n" + "LOCATION { 's3://bucket/folder/' }"; fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); @@ -92,7 +94,7 @@ void testErrorScenarios() { sqlQuery = "DESCRIBE TABLE FROM myS3.default.alb_logs"; fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); Assertions.assertEquals("FROM", fullyQualifiedTableName.getFullyQualifiedName()); Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("FROM", fullyQualifiedTableName.getTableName()); @@ -104,10 +106,12 @@ void testExtractionFromFlintIndexQueries() { String createCoveredIndexQuery = "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - Assertions.assertTrue(SQLQueryUtils.isIndexQuery(createCoveredIndexQuery)); - IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); - FullyQualifiedTableName fullyQualifiedTableName = 
indexDetails.getFullyQualifiedTableName(); - Assertions.assertEquals("elb_and_requestUri", indexDetails.getIndexName()); + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(createCoveredIndexQuery)); + IndexQueryDetails indexQueryDetails = + SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); + FullyQualifiedTableName fullyQualifiedTableName = + indexQueryDetails.getFullyQualifiedTableName(); + Assertions.assertEquals("elb_and_requestUri", indexQueryDetails.getIndexName()); Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); @@ -118,12 +122,99 @@ void testExtractionFromFlintMVQuery() { String createCoveredIndexQuery = "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + " (auto_refresh = true)"; - Assertions.assertTrue(SQLQueryUtils.isIndexQuery(createCoveredIndexQuery)); - IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(createCoveredIndexQuery)); + IndexQueryDetails indexQueryDetails = + SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); + FullyQualifiedTableName fullyQualifiedTableName = + indexQueryDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexQueryDetails.getIndexName()); + Assertions.assertNull(fullyQualifiedTableName); + Assertions.assertEquals("mv_1", indexQueryDetails.getMvName()); + } + + @Test + void testDescIndex() { + String descSkippingIndex = "DESC SKIPPING INDEX ON mys3.default.http_logs"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descSkippingIndex)); + IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(descSkippingIndex); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertNotNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); + + String descCoveringIndex = "DESC INDEX cv1 ON mys3.default.http_logs"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descCoveringIndex)); + indexDetails = SQLQueryUtils.extractIndexDetails(descCoveringIndex); + fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertEquals("cv1", indexDetails.getIndexName()); + Assertions.assertNotNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); + + String descMv = "DESC MATERIALIZED VIEW mv1"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descMv)); + indexDetails = SQLQueryUtils.extractIndexDetails(descMv); + fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertEquals("mv1", indexDetails.getMvName()); + Assertions.assertNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); + } + + @Test + void testShowIndex() { + String showCoveringIndex = " SHOW INDEX ON myS3.default.http_logs"; + 
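// Note the leading space in the statement: the Flint extension parser is expected to
// tolerate surrounding whitespace, as asserted below.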
Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(showCoveringIndex)); + IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(showCoveringIndex); FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertNull(indexDetails.getMvName()); + Assertions.assertNotNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType()); + + String showMV = "SHOW MATERIALIZED VIEW IN my_glue.default"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(showMV)); + indexDetails = SQLQueryUtils.extractIndexDetails(showMV); + fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertNull(indexDetails.getMvName()); + Assertions.assertNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType()); + } + + @Test + void testRefreshIndex() { + String refreshSkippingIndex = "REFRESH SKIPPING INDEX ON mys3.default.http_logs"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshSkippingIndex)); + IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(refreshSkippingIndex); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertNotNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); + + String refreshCoveringIndex = "REFRESH INDEX cv1 ON mys3.default.http_logs"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshCoveringIndex)); + indexDetails = SQLQueryUtils.extractIndexDetails(refreshCoveringIndex); + fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertEquals("cv1", indexDetails.getIndexName()); + Assertions.assertNotNull(fullyQualifiedTableName); + Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); + + String refreshMV = "REFRESH MATERIALIZED VIEW mv1"; + Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshMV)); + indexDetails = SQLQueryUtils.extractIndexDetails(refreshMV); + fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertNull(indexDetails.getIndexName()); + Assertions.assertEquals("mv1", indexDetails.getMvName()); Assertions.assertNull(fullyQualifiedTableName); - Assertions.assertEquals("mv_1", indexDetails.getMvName()); + Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); + Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); } /** https://github.com/opensearch-project/sql/issues/2206 */ From a5512f50771f165b6556d28eb8031b1ed918bc5d Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 25 Oct 2023 11:03:40 -0700 Subject: [PATCH 16/16] create new session if current session not ready (#2363) Signed-off-by: Peng Huo --- .../dispatcher/SparkQueryDispatcher.java | 8 ++- .../execution/session/InteractiveSession.java | 7 +++ 
.../sql/spark/execution/session/Session.java | 3 + ...AsyncQueryExecutorServiceImplSpecTest.java | 63 ++++++++++++++++++- .../dispatcher/SparkQueryDispatcherTest.java | 1 + 5 files changed, 78 insertions(+), 4 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 6ec67709b8..8feeddcafc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -213,7 +213,8 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); if (sessionManager.isEnabled()) { - Session session; + Session session = null; + if (dispatchQueryRequest.getSessionId() != null) { // get session from request SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); @@ -222,8 +223,9 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ throw new IllegalArgumentException("no session found. " + sessionId); } session = createdSession.get(); - } else { - // create session if not exist + } + if (session == null || !session.isReady()) { + // create session if not exist or session dead/fail tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText()); session = sessionManager.createSession( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 956275b04a..3221b33b2c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,7 +6,9 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.session.SessionState.DEAD; import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; +import static org.opensearch.sql.spark.execution.session.SessionState.FAIL; import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; @@ -130,4 +132,9 @@ public Optional get(StatementId stID) { .statementModel(model) .build()); } + + @Override + public boolean isReady() { + return sessionModel.getSessionState() != DEAD && sessionModel.getSessionState() != FAIL; + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index 4d919d5e2e..d3d3411ded 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -37,4 +37,7 @@ public interface Session { SessionModel getSessionModel(); SessionId getSessionId(); + + /** return true if session is ready to use. 
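+ * A session in DEAD or FAIL state is not ready; the dispatcher creates a new session instead of reusing it.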
*/ + boolean isReady(); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index f65049a7d9..6bc40c009b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -63,6 +63,7 @@ import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; @@ -390,6 +391,7 @@ public void withSessionCreateAsyncQueryFailed() { assertEquals("mock error", asyncQueryResults.getError()); } + // https://github.com/opensearch-project/sql/issues/2344 @Test public void createSessionMoreThanLimitFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); @@ -419,6 +421,65 @@ public void createSessionMoreThanLimitFailed() { "The maximum number of active sessions can be supported is 1", exception.getMessage()); } + // https://github.com/opensearch-project/sql/issues/2360 + @Test + public void recreateSessionIfNotReady() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // set sessionState to FAIL + setSessionState(first.getSessionId(), SessionState.FAIL); + + // 2. reuse session id + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, first.getSessionId())); + + assertNotEquals(first.getSessionId(), second.getSessionId()); + + // set sessionState to DEAD + setSessionState(second.getSessionId(), SessionState.DEAD); + + // 3. reuse session id + CreateAsyncQueryResponse third = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, second.getSessionId())); + assertNotEquals(second.getSessionId(), third.getSessionId()); + } + + @Test + public void submitQueryInInvalidSessionThrowException() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. use a session id that was never created. + SessionId sessionId = SessionId.newSessionId(DATASOURCE); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId()))); + assertEquals("no session found.
" + sessionId, exception.getMessage()); + } + private DataSourceServiceImpl createDataSourceService() { String masterKey = "a57d991d9b573f75b9bba1df"; DataSourceMetadataStorage dataSourceMetadataStorage = @@ -536,6 +597,6 @@ void setSessionState(String sessionId, SessionState sessionState) { Optional model = getSession(stateStore, DATASOURCE).apply(sessionId); SessionModel updated = updateSessionState(stateStore, DATASOURCE).apply(model.get(), sessionState); - assertEquals(SessionState.RUNNING, updated.getSessionState()); + assertEquals(sessionState, updated.getSessionState()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index fc8623d51a..743274d46c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -315,6 +315,7 @@ void testDispatchSelectQueryReuseSession() { doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); + when(session.isReady()).thenReturn(true); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);