Skip to content

Commit

Permalink
Create new session if client provided session is invalid (#2368)
Browse files Browse the repository at this point in the history
* Create new session if session is invalid

Signed-off-by: Peng Huo <[email protected]>

* fix code style

Signed-off-by: Peng Huo <[email protected]>

* fix UT

Signed-off-by: Peng Huo <[email protected]>

* fix error response

Signed-off-by: Peng Huo <[email protected]>

---------

Signed-off-by: Peng Huo <[email protected]>
  • Loading branch information
penghuo authored Oct 26, 2023
1 parent bb82d85 commit e16da37
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,9 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
// get session from request
SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId());
Optional<Session> createdSession = sessionManager.getSession(sessionId);
if (createdSession.isEmpty()) {
throw new IllegalArgumentException("no session found. " + sessionId);
if (createdSession.isPresent()) {
session = createdSession.get();
}
session = createdSession.get();
}
if (session == null || !session.isReady()) {
// create session if not exist or session dead/fail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class StatementModel extends StateModel {
public static final String QUERY_ID = "queryId";
public static final String SUBMIT_TIME = "submitTime";
public static final String ERROR = "error";
public static final String UNKNOWN = "unknown";
public static final String UNKNOWN = "";
public static final String STATEMENT_DOC_TYPE = "statement";

private final String version;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.plugins.Plugin;
Expand Down Expand Up @@ -227,6 +228,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() {
// 2. fetch async query result.
AsyncQueryExecutionResponse asyncQueryResults =
asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
assertTrue(Strings.isEmpty(asyncQueryResults.getError()));
assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus());

// 3. cancel async query.
Expand Down Expand Up @@ -460,24 +462,22 @@ public void recreateSessionIfNotReady() {
}

@Test
public void submitQueryInInvalidSessionThrowException() {
public void submitQueryInInvalidSessionWillCreateNewSession() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

// 1. create async query.
SessionId sessionId = SessionId.newSessionId(DATASOURCE);
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId())));
assertEquals("no session found. " + sessionId, exception.getMessage());
// 1. create async query with invalid sessionId
SessionId invalidSessionId = SessionId.newSessionId(DATASOURCE);
CreateAsyncQueryResponse asyncQuery =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, invalidSessionId.getSessionId()));
assertNotNull(asyncQuery.getSessionId());
assertNotEquals(invalidSessionId.getSessionId(), asyncQuery.getSessionId());
}

private DataSourceServiceImpl createDataSourceService() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,26 +327,6 @@ void testDispatchSelectQueryReuseSession() {
Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
}

@Test
void testDispatchSelectQueryInvalidSession() {
String query = "select * from my_glue.default.http_logs";
DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid");

doReturn(true).when(sessionManager).isEnabled();
doReturn(Optional.empty()).when(sessionManager).getSession(any());
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
IllegalArgumentException exception =
Assertions.assertThrows(
IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest));

verifyNoInteractions(emrServerlessClient);
verify(sessionManager, never()).createSession(any());
Assertions.assertEquals(
"no session found. " + new SessionId("invalid"), exception.getMessage());
}

@Test
void testDispatchSelectQueryFailedCreateSession() {
String query = "select * from my_glue.default.http_logs";
Expand Down

0 comments on commit e16da37

Please sign in to comment.