Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed May 9, 2024
1 parent b7bb16a commit 8e7fc7c
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ public StatementId submit(QueryRequest request) {

@Override
public Optional<Statement> get(StatementId stID) {
return statementStorageService.getStatement(stID.getId(), sessionModel.getDatasourceName())
return statementStorageService
.getStatement(stID.getId(), sessionModel.getDatasourceName())
.map(
model ->
Statement.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ public Session createSession(CreateSessionRequest request) {
* empty Optional if no matching session is found.
*/
public Optional<Session> getSession(SessionId sid, String dataSourceName) {
Optional<SessionModel> model = sessionStorageService.getSession(sid.getSessionId(), dataSourceName);
Optional<SessionModel> model =
sessionStorageService.getSession(sid.getSessionId(), dataSourceName);
if (model.isPresent()) {
InteractiveSession session =
InteractiveSession.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ public void cancel() {
}
try {
this.statementModel =
statementStorageService.updateStatementState(statementModel, StatementState.CANCELLED,
statementModel.getDatasourceName());
statementStorageService.updateStatementState(
statementModel, StatementState.CANCELLED, statementModel.getDatasourceName());
} catch (DocumentMissingException e) {
String errorMsg =
String.format("cancel statement failed. no statement found. statement: %s.", statementId);
LOG.error(errorMsg);
throw new IllegalStateException(errorMsg);
} catch (VersionConflictEngineException e) {
this.statementModel =
statementStorageService.getStatement(statementModel.getId(), statementModel.getDatasourceName())
statementStorageService
.getStatement(statementModel.getId(), statementModel.getDatasourceName())
.orElse(this.statementModel);
String errorMsg =
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,17 @@ public SessionManager sessionManager(
StatementStorageService statementStorageService,
EMRServerlessClientFactory emrServerlessClientFactory,
Settings settings) {
return new SessionManager(sessionStorageService, statementStorageService, emrServerlessClientFactory, settings);
return new SessionManager(
sessionStorageService, statementStorageService, emrServerlessClientFactory, settings);
}

@Provides
public SessionStorageService sessionStorageService(
StateStore stateStore
) {
public SessionStorageService sessionStorageService(StateStore stateStore) {
return new OpenSearchSessionStorageService(stateStore);
}

@Provides
public StatementStorageService statementStorageService(
StateStore stateStore
) {
public StatementStorageService statementStorageService(StateStore stateStore) {
return new OpenSearchStatementStorageService(stateStore);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,11 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
new QueryHandlerFactory(
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
new SessionManager(sessionStorageService, statementStorageService, emrServerlessClientFactory, pluginSettings),
new SessionManager(
sessionStorageService,
statementStorageService,
emrServerlessClientFactory,
pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore),
new FlintIndexOpFactory(
Expand All @@ -240,7 +244,11 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
this.dataSourceService,
new SessionManager(sessionStorageService, statementStorageService, emrServerlessClientFactory, pluginSettings),
new SessionManager(
sessionStorageService,
statementStorageService,
emrServerlessClientFactory,
pluginSettings),
queryHandlerFactory);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
Expand Down Expand Up @@ -358,7 +366,8 @@ int search(QueryBuilder query) {

void setSessionState(String sessionId, SessionState sessionState) {
Optional<SessionModel> model = sessionStorageService.getSession(sessionId, MYS3_DATASOURCE);
SessionModel updated = sessionStorageService.updateSessionState(model.get(), sessionState, MYS3_DATASOURCE);
SessionModel updated =
sessionStorageService.updateSessionState(model.get(), sessionState, MYS3_DATASOURCE);
assertEquals(sessionState, updated.getSessionState());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.dispatcher.model.JobType;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStatementStorageService;
import org.opensearch.sql.spark.execution.statestore.SessionStorageService;
import org.opensearch.sql.spark.execution.statestore.StateStore;
Expand All @@ -33,7 +33,8 @@
/** mock-maker-inline does not work with OpenSearchTestCase. */
public class InteractiveSessionTest extends OpenSearchIntegTestCase {

private static final String indexName = OpenSearchStateStoreUtil.getIndexName(TEST_DATASOURCE_NAME);
private static final String indexName =
OpenSearchStateStoreUtil.getIndexName(TEST_DATASOURCE_NAME);

private TestEMRServerlessClient emrsClient;
private StartJobRequest startJobRequest;
Expand All @@ -49,7 +50,12 @@ public void setup() {
sessionStorageService = new OpenSearchSessionStorageService(stateStore);
statementStorageService = new OpenSearchStatementStorageService(stateStore);
EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient;
sessionManager = new SessionManager(sessionStorageService, statementStorageService, emrServerlessClientFactory, sessionSetting());
sessionManager =
new SessionManager(
sessionStorageService,
statementStorageService,
emrServerlessClientFactory,
sessionSetting());
}

@After
Expand All @@ -71,7 +77,8 @@ public void openCloseSession() {
.build();

SessionAssertions assertions = new SessionAssertions(session);
assertions.open(createSessionRequest())
assertions
.open(createSessionRequest())
.assertSessionState(NOT_STARTED)
.assertAppId("appId")
.assertJobId("jobId");
Expand Down Expand Up @@ -132,7 +139,8 @@ public void closeNotExistSession() {
public void sessionManagerCreateSession() {
Session session = sessionManager.createSession(createSessionRequest());

new SessionAssertions(session).assertSessionState(NOT_STARTED)
new SessionAssertions(session)
.assertSessionState(NOT_STARTED)
.assertAppId("appId")
.assertJobId("jobId");
}
Expand Down Expand Up @@ -188,6 +196,4 @@ public SessionAssertions close() {
return this;
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.statestore.SessionStorageService;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.execution.statestore.StatementStorageService;

@ExtendWith(MockitoExtension.class)
Expand All @@ -27,7 +26,12 @@ public class SessionManagerTest {

@Test
public void sessionEnable() {
SessionManager sessionManager = new SessionManager(sessionStorageService, statementStorageService, emrServerlessClientFactory, sessionSetting());
SessionManager sessionManager =
new SessionManager(
sessionStorageService,
statementStorageService,
emrServerlessClientFactory,
sessionSetting());

Assertions.assertTrue(sessionManager.isEnabled());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ public void setup() {
sessionStorageService = new OpenSearchSessionStorageService(stateStore);
EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient;

sessionManager = new SessionManager(sessionStorageService, statementStorageService, emrServerlessClientFactory, sessionSetting());
sessionManager =
new SessionManager(
sessionStorageService,
statementStorageService,
emrServerlessClientFactory,
sessionSetting());
}

@After
Expand Down Expand Up @@ -123,12 +128,12 @@ public void cancelNotExistStatement() {
@Test
public void cancelFailedBecauseOfConflict() {
StatementId stId = new StatementId("statementId");
Statement st =
buildStatement(stId);
Statement st = buildStatement(stId);
st.open();

StatementModel running = statementStorageService.updateStatementState(st.getStatementModel(), CANCELLED,
TEST_DATASOURCE_NAME);
StatementModel running =
statementStorageService.updateStatementState(
st.getStatementModel(), CANCELLED, TEST_DATASOURCE_NAME);

assertEquals(StatementState.CANCELLED, running.getStatementState());

Expand Down Expand Up @@ -202,8 +207,7 @@ public void cancelCancelledStatementFailed() {

@Test
public void cancelRunningStatementSuccess() {
Statement st =
buildStatement();
Statement st = buildStatement();

// submit statement
TestStatement testStatement = testStatement(st, statementStorageService);
Expand All @@ -223,8 +227,8 @@ public void submitStatementInRunningSession() {
Session session = sessionManager.createSession(createSessionRequest());

// App change state to running
sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING,
TEST_DATASOURCE_NAME);
sessionStorageService.updateSessionState(
session.getSessionModel(), SessionState.RUNNING, TEST_DATASOURCE_NAME);

StatementId statementId = session.submit(queryRequest());
assertFalse(statementId.getId().isEmpty());
Expand All @@ -242,8 +246,8 @@ public void submitStatementInNotStartedState() {
public void failToSubmitStatementInDeadState() {
Session session = sessionManager.createSession(createSessionRequest());

sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.DEAD,
TEST_DATASOURCE_NAME);
sessionStorageService.updateSessionState(
session.getSessionModel(), SessionState.DEAD, TEST_DATASOURCE_NAME);

IllegalStateException exception =
assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
Expand All @@ -257,8 +261,8 @@ public void failToSubmitStatementInDeadState() {
public void failToSubmitStatementInFailState() {
Session session = sessionManager.createSession(createSessionRequest());

sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.FAIL,
TEST_DATASOURCE_NAME);
sessionStorageService.updateSessionState(
session.getSessionModel(), SessionState.FAIL, TEST_DATASOURCE_NAME);

IllegalStateException exception =
assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
Expand All @@ -270,9 +274,7 @@ public void failToSubmitStatementInFailState() {

@Test
public void newStatementFieldAssert() {
Session session =
sessionManager
.createSession(createSessionRequest());
Session session = sessionManager.createSession(createSessionRequest());
StatementId statementId = session.submit(queryRequest());
Optional<Statement> statement = session.get(statementId);

Expand Down Expand Up @@ -305,8 +307,8 @@ public void failToSubmitStatementInDeletedSession() {
public void getStatementSuccess() {
Session session = sessionManager.createSession(createSessionRequest());
// App change state to running
sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING,
TEST_DATASOURCE_NAME);
sessionStorageService.updateSessionState(
session.getSessionModel(), SessionState.RUNNING, TEST_DATASOURCE_NAME);
StatementId statementId = session.submit(queryRequest());

Optional<Statement> statement = session.get(statementId);
Expand All @@ -317,11 +319,10 @@ public void getStatementSuccess() {

@Test
public void getStatementNotExist() {
Session session = sessionManager
.createSession(createSessionRequest());
Session session = sessionManager.createSession(createSessionRequest());
// App change state to running
sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING,
TEST_DATASOURCE_NAME);
sessionStorageService.updateSessionState(
session.getSessionModel(), SessionState.RUNNING, TEST_DATASOURCE_NAME);

Optional<Statement> statement = session.get(StatementId.newStatementId("not-exist-id"));
assertFalse(statement.isPresent());
Expand All @@ -332,16 +333,16 @@ static class TestStatement {
private final Statement st;
private final StatementStorageService statementStorageService;

public static TestStatement testStatement(Statement st, StatementStorageService statementStorageService) {
public static TestStatement testStatement(
Statement st, StatementStorageService statementStorageService) {
return new TestStatement(st, statementStorageService);
}

public TestStatement assertSessionState(StatementState expected) {
assertEquals(expected, st.getStatementModel().getStatementState());

Optional<StatementModel> model = statementStorageService.getStatement(
st.getStatementId().getId(), TEST_DATASOURCE_NAME
);
Optional<StatementModel> model =
statementStorageService.getStatement(st.getStatementId().getId(), TEST_DATASOURCE_NAME);
assertTrue(model.isPresent());
assertEquals(expected, model.get().getStatementState());

Expand All @@ -351,9 +352,8 @@ public TestStatement assertSessionState(StatementState expected) {
public TestStatement assertStatementId(StatementId expected) {
assertEquals(expected, st.getStatementModel().getStatementId());

Optional<StatementModel> model = statementStorageService.getStatement(
st.getStatementId().getId(), TEST_DATASOURCE_NAME
);
Optional<StatementModel> model =
statementStorageService.getStatement(st.getStatementId().getId(), TEST_DATASOURCE_NAME);
assertTrue(model.isPresent());
assertEquals(expected, model.get().getStatementId());
return this;
Expand All @@ -370,15 +370,17 @@ public TestStatement cancel() {
}

public TestStatement run() {
StatementModel model = statementStorageService.updateStatementState(st.getStatementModel(), RUNNING,
TEST_DATASOURCE_NAME);
StatementModel model =
statementStorageService.updateStatementState(
st.getStatementModel(), RUNNING, TEST_DATASOURCE_NAME);
st.setStatementModel(model);
return this;
}
}

private QueryRequest queryRequest() {
return new QueryRequest(AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME), LangType.SQL, "select 1");
return new QueryRequest(
AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME), LangType.SQL, "select 1");
}

private Statement createStatement(StatementId stId) {
Expand Down

0 comments on commit 8e7fc7c

Please sign in to comment.