From e16da37e167298a251a4cbd7750612b8792f0129 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 25 Oct 2023 17:38:08 -0700 Subject: [PATCH] Create new session if client provided session is invalid (#2368) * Create new session if session is invalid Signed-off-by: Peng Huo * fix code style Signed-off-by: Peng Huo * fix UT Signed-off-by: Peng Huo * fix error response Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- .../dispatcher/SparkQueryDispatcher.java | 5 ++--- .../execution/statement/StatementModel.java | 2 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 22 +++++++++---------- .../dispatcher/SparkQueryDispatcherTest.java | 20 ----------------- 4 files changed, 14 insertions(+), 35 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 8feeddcafc..5e80259e09 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -219,10 +219,9 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ // get session from request SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); Optional createdSession = sessionManager.getSession(sessionId); - if (createdSession.isEmpty()) { - throw new IllegalArgumentException("no session found. " + sessionId); + if (createdSession.isPresent()) { + session = createdSession.get(); } - session = createdSession.get(); } if (session == null || !session.isReady()) { // create session if not exist or session dead/fail diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index 2a1043bf73..adc147c905 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -36,7 +36,7 @@ public class StatementModel extends StateModel { public static final String QUERY_ID = "queryId"; public static final String SUBMIT_TIME = "submitTime"; public static final String ERROR = "error"; - public static final String UNKNOWN = "unknown"; + public static final String UNKNOWN = ""; public static final String STATEMENT_DOC_TYPE = "statement"; private final String version; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 6bc40c009b..cf638effc6 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -45,6 +45,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.plugins.Plugin; @@ -227,6 +228,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { // 2. fetch async query result. AsyncQueryExecutionResponse asyncQueryResults = asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertTrue(Strings.isEmpty(asyncQueryResults.getError())); assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); // 3. cancel async query. @@ -460,7 +462,7 @@ public void recreateSessionIfNotReady() { } @Test - public void submitQueryInInvalidSessionThrowException() { + public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrsClient); @@ -468,16 +470,14 @@ public void submitQueryInInvalidSessionThrowException() { // enable session enableSession(true); - // 1. create async query. - SessionId sessionId = SessionId.newSessionId(DATASOURCE); - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest( - "select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId()))); - assertEquals("no session found. " + sessionId, exception.getMessage()); + // 1. create async query with invalid sessionId + SessionId invalidSessionId = SessionId.newSessionId(DATASOURCE); + CreateAsyncQueryResponse asyncQuery = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, invalidSessionId.getSessionId())); + assertNotNull(asyncQuery.getSessionId()); + assertNotEquals(invalidSessionId.getSessionId(), asyncQuery.getSessionId()); } private DataSourceServiceImpl createDataSourceService() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 743274d46c..95b6033d12 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -327,26 +327,6 @@ void testDispatchSelectQueryReuseSession() { Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); } - @Test - void testDispatchSelectQueryInvalidSession() { - String query = "select * from my_glue.default.http_logs"; - DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid"); - - doReturn(true).when(sessionManager).isEnabled(); - doReturn(Optional.empty()).when(sessionManager).getSession(any()); - DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); - doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - IllegalArgumentException exception = - Assertions.assertThrows( - IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); - - verifyNoInteractions(emrServerlessClient); - verify(sessionManager, never()).createSession(any()); - Assertions.assertEquals( - "no session found. " + new SessionId("invalid"), exception.getMessage()); - } - @Test void testDispatchSelectQueryFailedCreateSession() { String query = "select * from my_glue.default.http_logs";