From bffd6b154e612ce94aadcd2796edcb635467e7b6 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Mon, 6 May 2024 13:26:48 -0700 Subject: [PATCH 1/4] Refactor IndexDMLHandler and related classes (#2644) Signed-off-by: Tomoyuki Morita (cherry picked from commit 45122ec67bfb517ea6ab445624fa82b8a1663a4f) --- .../org/opensearch/sql/plugin/SQLPlugin.java | 6 +- .../cluster/ClusterManagerEventListener.java | 17 +- .../FlintStreamingJobHouseKeeperTask.java | 21 +- .../sql/spark/dispatcher/IndexDMLHandler.java | 64 +- .../spark/dispatcher/QueryHandlerFactory.java | 17 +- .../spark/dispatcher/RefreshQueryHandler.java | 16 +- .../flint/IndexDMLResultStorageService.java | 12 + ...penSearchIndexDMLResultStorageService.java | 25 + .../spark/flint/operation/FlintIndexOp.java | 6 +- .../flint/operation/FlintIndexOpAlter.java | 10 +- .../flint/operation/FlintIndexOpCancel.java | 13 +- .../flint/operation/FlintIndexOpDrop.java | 13 +- .../flint/operation/FlintIndexOpFactory.java | 42 ++ .../flint/operation/FlintIndexOpVacuum.java | 9 +- .../config/AsyncExecutorServiceModule.java | 27 +- .../AsyncQueryExecutorServiceSpec.java | 33 +- .../FlintStreamingJobHouseKeeperTaskTest.java | 648 +++++++----------- .../spark/dispatcher/IndexDMLHandlerTest.java | 26 +- .../dispatcher/SparkQueryDispatcherTest.java | 27 +- .../flint/operation/FlintIndexOpTest.java | 18 +- 20 files changed, 486 insertions(+), 564 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index bc0a084f8c..16fd46c253 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -79,10 +79,9 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.cluster.ClusterManagerEventListener; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; @@ -227,8 +226,7 @@ public Collection createComponents( environment.settings(), dataSourceService, injector.getInstance(FlintIndexMetadataServiceImpl.class), - injector.getInstance(StateStore.class), - injector.getInstance(EMRServerlessClientFactory.class)); + injector.getInstance(FlintIndexOpFactory.class)); return ImmutableList.of( dataSourceService, injector.getInstance(AsyncQueryExecutorService.class), diff --git a/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java b/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java index f04c6cb830..6c660f073c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java +++ b/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java @@ -21,9 +21,8 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; @@ -37,8 +36,7 @@ public class ClusterManagerEventListener implements LocalNodeClusterManagerListe private Clock clock; private DataSourceService dataSourceService; private FlintIndexMetadataService flintIndexMetadataService; - private StateStore stateStore; - private EMRServerlessClientFactory emrServerlessClientFactory; + private FlintIndexOpFactory flintIndexOpFactory; private Duration sessionTtlDuration; private Duration resultTtlDuration; private TimeValue streamingJobHouseKeepingInterval; @@ -56,8 +54,7 @@ public ClusterManagerEventListener( Settings settings, DataSourceService dataSourceService, FlintIndexMetadataService flintIndexMetadataService, - StateStore stateStore, - EMRServerlessClientFactory emrServerlessClientFactory) { + FlintIndexOpFactory flintIndexOpFactory) { this.clusterService = clusterService; this.threadPool = threadPool; this.client = client; @@ -65,8 +62,7 @@ public ClusterManagerEventListener( this.clock = clock; this.dataSourceService = dataSourceService; this.flintIndexMetadataService = flintIndexMetadataService; - this.stateStore = stateStore; - this.emrServerlessClientFactory = emrServerlessClientFactory; + this.flintIndexOpFactory = flintIndexOpFactory; this.sessionTtlDuration = toDuration(sessionTtl.get(settings)); this.resultTtlDuration = toDuration(resultTtl.get(settings)); this.streamingJobHouseKeepingInterval = streamingJobHouseKeepingInterval.get(settings); @@ -151,10 +147,7 @@ private void initializeStreamingJobHouseKeeperCron() { flintStreamingJobHouseKeeperCron = threadPool.scheduleWithFixedDelay( new FlintStreamingJobHouseKeeperTask( - dataSourceService, - flintIndexMetadataService, - stateStore, - emrServerlessClientFactory), + dataSourceService, flintIndexMetadataService, flintIndexOpFactory), streamingJobHouseKeepingInterval, executorName()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java b/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java index 27221f1b72..31b1ecb49c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java +++ b/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java @@ -17,13 +17,10 @@ import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; -import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; -import org.opensearch.sql.spark.flint.operation.FlintIndexOpAlter; -import org.opensearch.sql.spark.flint.operation.FlintIndexOpDrop; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; /** Cleaner task which alters the active streaming jobs of a disabled datasource. */ @RequiredArgsConstructor @@ -31,8 +28,7 @@ public class FlintStreamingJobHouseKeeperTask implements Runnable { private final DataSourceService dataSourceService; private final FlintIndexMetadataService flintIndexMetadataService; - private final StateStore stateStore; - private final EMRServerlessClientFactory emrServerlessClientFactory; + private final FlintIndexOpFactory flintIndexOpFactory; private static final Logger LOGGER = LogManager.getLogger(FlintStreamingJobHouseKeeperTask.class); protected static final AtomicBoolean isRunning = new AtomicBoolean(false); @@ -95,9 +91,7 @@ private void dropAutoRefreshIndex( String autoRefreshIndex, FlintIndexMetadata flintIndexMetadata, String datasourceName) { // When the datasource is deleted. Possibly Replace with VACUUM Operation. LOGGER.info("Attempting to drop auto refresh index: {}", autoRefreshIndex); - FlintIndexOpDrop flintIndexOpDrop = - new FlintIndexOpDrop(stateStore, datasourceName, emrServerlessClientFactory.getClient()); - flintIndexOpDrop.apply(flintIndexMetadata); + flintIndexOpFactory.getDrop(datasourceName).apply(flintIndexMetadata); LOGGER.info("Successfully dropped index: {}", autoRefreshIndex); } @@ -106,14 +100,7 @@ private void alterAutoRefreshIndex( LOGGER.info("Attempting to alter index: {}", autoRefreshIndex); FlintIndexOptions flintIndexOptions = new FlintIndexOptions(); flintIndexOptions.setOption(FlintIndexOptions.AUTO_REFRESH, "false"); - FlintIndexOpAlter flintIndexOpAlter = - new FlintIndexOpAlter( - flintIndexOptions, - stateStore, - datasourceName, - emrServerlessClientFactory.getClient(), - flintIndexMetadataService); - flintIndexOpAlter.apply(flintIndexMetadata); + flintIndexOpFactory.getAlter(flintIndexOptions, datasourceName).apply(flintIndexMetadata); LOGGER.info("Successfully altered index: {}", autoRefreshIndex); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index 412db50e85..dfd5316f6c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createIndexDMLResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.util.Map; @@ -16,24 +15,20 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.json.JSONObject; -import org.opensearch.client.Client; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; import org.opensearch.sql.spark.flint.operation.FlintIndexOp; -import org.opensearch.sql.spark.flint.operation.FlintIndexOpAlter; -import org.opensearch.sql.spark.flint.operation.FlintIndexOpDrop; -import org.opensearch.sql.spark.flint.operation.FlintIndexOpVacuum; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** Handle Index DML query. includes * DROP * ALT? */ @@ -45,15 +40,10 @@ public class IndexDMLHandler extends AsyncQueryHandler { public static final String DROP_INDEX_JOB_ID = "dropIndexJobId"; public static final String DML_QUERY_JOB_ID = "DMLQueryJobId"; - private final EMRServerlessClient emrServerlessClient; - private final JobExecutionResponseReader jobExecutionResponseReader; - private final FlintIndexMetadataService flintIndexMetadataService; - - private final StateStore stateStore; - - private final Client client; + private final IndexDMLResultStorageService indexDMLResultStorageService; + private final FlintIndexOpFactory flintIndexOpFactory; public static boolean isIndexDMLQuery(String jobId) { return DROP_INDEX_JOB_ID.equalsIgnoreCase(jobId) || DML_QUERY_JOB_ID.equalsIgnoreCase(jobId); @@ -67,14 +57,16 @@ public DispatchQueryResponse submit( try { IndexQueryDetails indexDetails = context.getIndexQueryDetails(); FlintIndexMetadata indexMetadata = getFlintIndexMetadata(indexDetails); - executeIndexOp(dispatchQueryRequest, indexDetails, indexMetadata); + + getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata); + AsyncQueryId asyncQueryId = storeIndexDMLResult( dispatchQueryRequest, dataSourceMetadata, JobRunState.SUCCESS.toString(), StringUtils.EMPTY, - startTime); + getElapsedTimeSince(startTime)); return new DispatchQueryResponse( asyncQueryId, DML_QUERY_JOB_ID, dataSourceMetadata.getResultIndex(), null); } catch (Exception e) { @@ -85,7 +77,7 @@ public DispatchQueryResponse submit( dataSourceMetadata, JobRunState.FAILED.toString(), e.getMessage(), - startTime); + getElapsedTimeSince(startTime)); return new DispatchQueryResponse( asyncQueryId, DML_QUERY_JOB_ID, dataSourceMetadata.getResultIndex(), null); } @@ -96,7 +88,7 @@ private AsyncQueryId storeIndexDMLResult( DataSourceMetadata dataSourceMetadata, String status, String error, - long startTime) { + long queryRunTime) { AsyncQueryId asyncQueryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); IndexDMLResult indexDMLResult = new IndexDMLResult( @@ -104,38 +96,26 @@ private AsyncQueryId storeIndexDMLResult( status, error, dispatchQueryRequest.getDatasource(), - System.currentTimeMillis() - startTime, + queryRunTime, System.currentTimeMillis()); - createIndexDMLResult(stateStore, dataSourceMetadata.getResultIndex()).apply(indexDMLResult); + indexDMLResultStorageService.createIndexDMLResult(indexDMLResult, dataSourceMetadata.getName()); return asyncQueryId; } - private void executeIndexOp( - DispatchQueryRequest dispatchQueryRequest, - IndexQueryDetails indexQueryDetails, - FlintIndexMetadata indexMetadata) { + private long getElapsedTimeSince(long startTime) { + return System.currentTimeMillis() - startTime; + } + + private FlintIndexOp getIndexOp( + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { switch (indexQueryDetails.getIndexQueryActionType()) { case DROP: - FlintIndexOp dropOp = - new FlintIndexOpDrop( - stateStore, dispatchQueryRequest.getDatasource(), emrServerlessClient); - dropOp.apply(indexMetadata); - break; + return flintIndexOpFactory.getDrop(dispatchQueryRequest.getDatasource()); case ALTER: - FlintIndexOpAlter flintIndexOpAlter = - new FlintIndexOpAlter( - indexQueryDetails.getFlintIndexOptions(), - stateStore, - dispatchQueryRequest.getDatasource(), - emrServerlessClient, - flintIndexMetadataService); - flintIndexOpAlter.apply(indexMetadata); - break; + return flintIndexOpFactory.getAlter( + indexQueryDetails.getFlintIndexOptions(), dispatchQueryRequest.getDatasource()); case VACUUM: - FlintIndexOp indexVacuumOp = - new FlintIndexOpVacuum(stateStore, dispatchQueryRequest.getDatasource(), client); - indexVacuumOp.apply(indexMetadata); - break; + return flintIndexOpFactory.getVacuum(dispatchQueryRequest.getDatasource()); default: throw new IllegalStateException( String.format( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java index 1713bed4e2..f994d9c728 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -6,11 +6,11 @@ package org.opensearch.sql.spark.dispatcher; import lombok.RequiredArgsConstructor; -import org.opensearch.client.Client; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -19,10 +19,10 @@ public class QueryHandlerFactory { private final JobExecutionResponseReader jobExecutionResponseReader; private final FlintIndexMetadataService flintIndexMetadataService; - private final Client client; private final SessionManager sessionManager; private final LeaseManager leaseManager; - private final StateStore stateStore; + private final IndexDMLResultStorageService indexDMLResultStorageService; + private final FlintIndexOpFactory flintIndexOpFactory; private final EMRServerlessClientFactory emrServerlessClientFactory; public RefreshQueryHandler getRefreshQueryHandler() { @@ -30,8 +30,8 @@ public RefreshQueryHandler getRefreshQueryHandler() { emrServerlessClientFactory.getClient(), jobExecutionResponseReader, flintIndexMetadataService, - stateStore, - leaseManager); + leaseManager, + flintIndexOpFactory); } public StreamingQueryHandler getStreamingQueryHandler() { @@ -50,10 +50,9 @@ public InteractiveQueryHandler getInteractiveQueryHandler() { public IndexDMLHandler getIndexDMLHandler() { return new IndexDMLHandler( - emrServerlessClientFactory.getClient(), jobExecutionResponseReader, flintIndexMetadataService, - stateStore, - client); + indexDMLResultStorageService, + flintIndexOpFactory); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index d55408f62e..aeb5c1b35f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -13,11 +13,10 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.JobType; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.operation.FlintIndexOp; -import org.opensearch.sql.spark.flint.operation.FlintIndexOpCancel; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -25,19 +24,17 @@ public class RefreshQueryHandler extends BatchQueryHandler { private final FlintIndexMetadataService flintIndexMetadataService; - private final StateStore stateStore; - private final EMRServerlessClient emrServerlessClient; + private final FlintIndexOpFactory flintIndexOpFactory; public RefreshQueryHandler( EMRServerlessClient emrServerlessClient, JobExecutionResponseReader jobExecutionResponseReader, FlintIndexMetadataService flintIndexMetadataService, - StateStore stateStore, - LeaseManager leaseManager) { + LeaseManager leaseManager, + FlintIndexOpFactory flintIndexOpFactory) { super(emrServerlessClient, jobExecutionResponseReader, leaseManager); this.flintIndexMetadataService = flintIndexMetadataService; - this.stateStore = stateStore; - this.emrServerlessClient = emrServerlessClient; + this.flintIndexOpFactory = flintIndexOpFactory; } @Override @@ -51,8 +48,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { "Couldn't fetch flint index: %s details", asyncQueryJobMetadata.getIndexName())); } FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); - FlintIndexOp jobCancelOp = - new FlintIndexOpCancel(stateStore, datasourceName, emrServerlessClient); + FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName); jobCancelOp.apply(indexMetadata); return asyncQueryJobMetadata.getQueryId().getId(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java new file mode 100644 index 0000000000..4a046564f5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; + +public interface IndexDMLResultStorageService { + IndexDMLResult createIndexDMLResult(IndexDMLResult result, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java new file mode 100644 index 0000000000..eeb2921449 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; +import org.opensearch.sql.spark.execution.statestore.StateStore; + +@RequiredArgsConstructor +public class OpenSearchIndexDMLResultStorageService implements IndexDMLResultStorageService { + + private final DataSourceService dataSourceService; + private final StateStore stateStore; + + @Override + public IndexDMLResult createIndexDMLResult(IndexDMLResult result, String datasourceName) { + DataSourceMetadata dataSourceMetadata = dataSourceService.getDataSourceMetadata(datasourceName); + return stateStore.create(result, IndexDMLResult::copy, dataSourceMetadata.getResultIndex()); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 8d5e301631..edfd0aace2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -21,6 +21,7 @@ import org.jetbrains.annotations.NotNull; import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -33,6 +34,7 @@ public abstract class FlintIndexOp { private final StateStore stateStore; private final String datasourceName; + private final EMRServerlessClientFactory emrServerlessClientFactory; /** Apply operation on {@link FlintIndexMetadata} */ public void apply(FlintIndexMetadata metadata) { @@ -140,11 +142,11 @@ private void commit(FlintIndexStateModel flintIndex) { /*** * Common operation between AlterOff and Drop. So moved to FlintIndexOp. */ - public void cancelStreamingJob( - EMRServerlessClient emrServerlessClient, FlintIndexStateModel flintIndexStateModel) + public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel) throws InterruptedException, TimeoutException { String applicationId = flintIndexStateModel.getApplicationId(); String jobId = flintIndexStateModel.getJobId(); + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); try { emrServerlessClient.cancelJobRun( flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true); diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 7db4f6a4c6..31e33539a1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -8,7 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -22,7 +22,6 @@ */ public class FlintIndexOpAlter extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(FlintIndexOpAlter.class); - private final EMRServerlessClient emrServerlessClient; private final FlintIndexMetadataService flintIndexMetadataService; private final FlintIndexOptions flintIndexOptions; @@ -30,10 +29,9 @@ public FlintIndexOpAlter( FlintIndexOptions flintIndexOptions, StateStore stateStore, String datasourceName, - EMRServerlessClient emrServerlessClient, + EMRServerlessClientFactory emrServerlessClientFactory, FlintIndexMetadataService flintIndexMetadataService) { - super(stateStore, datasourceName); - this.emrServerlessClient = emrServerlessClient; + super(stateStore, datasourceName, emrServerlessClientFactory); this.flintIndexMetadataService = flintIndexMetadataService; this.flintIndexOptions = flintIndexOptions; } @@ -55,7 +53,7 @@ void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintInde "Running alter index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); this.flintIndexMetadataService.updateIndexToManualRefresh( flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions); - cancelStreamingJob(emrServerlessClient, flintIndexStateModel); + cancelStreamingJob(flintIndexStateModel); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java index 2317c5b6dc..0962e2a16b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java @@ -8,7 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -18,12 +18,11 @@ public class FlintIndexOpCancel extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); - private final EMRServerlessClient emrServerlessClient; - public FlintIndexOpCancel( - StateStore stateStore, String datasourceName, EMRServerlessClient emrServerlessClient) { - super(stateStore, datasourceName); - this.emrServerlessClient = emrServerlessClient; + StateStore stateStore, + String datasourceName, + EMRServerlessClientFactory emrServerlessClientFactory) { + super(stateStore, datasourceName, emrServerlessClientFactory); } // Only in refreshing state, the job is cancellable in case of REFRESH query. @@ -43,7 +42,7 @@ void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintInde LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); - cancelStreamingJob(emrServerlessClient, flintIndexStateModel); + cancelStreamingJob(flintIndexStateModel); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java index 586c346863..0f71b3bc70 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java @@ -8,7 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -17,12 +17,11 @@ public class FlintIndexOpDrop extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); - private final EMRServerlessClient emrServerlessClient; - public FlintIndexOpDrop( - StateStore stateStore, String datasourceName, EMRServerlessClient emrServerlessClient) { - super(stateStore, datasourceName); - this.emrServerlessClient = emrServerlessClient; + StateStore stateStore, + String datasourceName, + EMRServerlessClientFactory emrServerlessClientFactory) { + super(stateStore, datasourceName, emrServerlessClientFactory); } public boolean validate(FlintIndexState state) { @@ -44,7 +43,7 @@ void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintInde LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); - cancelStreamingJob(emrServerlessClient, flintIndexStateModel); + cancelStreamingJob(flintIndexStateModel); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java new file mode 100644 index 0000000000..6fc2261ade --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint.operation; + +import lombok.RequiredArgsConstructor; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataService; + +@RequiredArgsConstructor +public class FlintIndexOpFactory { + private final StateStore stateStore; + private final Client client; + private final FlintIndexMetadataService flintIndexMetadataService; + private final EMRServerlessClientFactory emrServerlessClientFactory; + + public FlintIndexOpDrop getDrop(String datasource) { + return new FlintIndexOpDrop(stateStore, datasource, emrServerlessClientFactory); + } + + public FlintIndexOpAlter getAlter(FlintIndexOptions flintIndexOptions, String datasource) { + return new FlintIndexOpAlter( + flintIndexOptions, + stateStore, + datasource, + emrServerlessClientFactory, + flintIndexMetadataService); + } + + public FlintIndexOpVacuum getVacuum(String datasource) { + return new FlintIndexOpVacuum(stateStore, datasource, client, emrServerlessClientFactory); + } + + public FlintIndexOpCancel getCancel(String datasource) { + return new FlintIndexOpCancel(stateStore, datasource, emrServerlessClientFactory); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java index cf204450e7..4287d9c7c9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java @@ -10,6 +10,7 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Client; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -23,8 +24,12 @@ public class FlintIndexOpVacuum extends FlintIndexOp { /** OpenSearch client. */ private final Client client; - public FlintIndexOpVacuum(StateStore stateStore, String datasourceName, Client client) { - super(stateStore, datasourceName); + public FlintIndexOpVacuum( + StateStore stateStore, + String datasourceName, + Client client, + EMRServerlessClientFactory emrServerlessClientFactory) { + super(stateStore, datasourceName, emrServerlessClientFactory); this.client = client; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index f93d065855..1d890ce346 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -30,6 +30,9 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.OpenSearchIndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -76,21 +79,37 @@ public SparkQueryDispatcher sparkQueryDispatcher( public QueryHandlerFactory queryhandlerFactory( JobExecutionResponseReader jobExecutionResponseReader, FlintIndexMetadataServiceImpl flintIndexMetadataReader, - NodeClient client, SessionManager sessionManager, DefaultLeaseManager defaultLeaseManager, - StateStore stateStore, + IndexDMLResultStorageService indexDMLResultStorageService, + FlintIndexOpFactory flintIndexOpFactory, EMRServerlessClientFactory emrServerlessClientFactory) { return new QueryHandlerFactory( jobExecutionResponseReader, flintIndexMetadataReader, - client, sessionManager, defaultLeaseManager, - stateStore, + indexDMLResultStorageService, + flintIndexOpFactory, emrServerlessClientFactory); } + @Provides + public FlintIndexOpFactory flintIndexOpFactory( + StateStore stateStore, + NodeClient client, + FlintIndexMetadataServiceImpl flintIndexMetadataService, + EMRServerlessClientFactory emrServerlessClientFactory) { + return new FlintIndexOpFactory( + stateStore, client, flintIndexMetadataService, emrServerlessClientFactory); + } + + @Provides + public IndexDMLResultStorageService indexDMLResultStorageService( + DataSourceService dataSourceService, StateStore stateStore) { + return new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore); + } + @Provides public SessionManager sessionManager( StateStore stateStore, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index fdd094259f..b1c7f68388 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -64,14 +64,18 @@ import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.flint.OpenSearchIndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { + public static final String MYS3_DATASOURCE = "mys3"; public static final String MYGLUE_DATASOURCE = "my_glue"; @@ -81,6 +85,7 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected DataSourceServiceImpl dataSourceService; protected StateStore stateStore; protected ClusterSettings clusterSettings; + protected FlintIndexMetadataService flintIndexMetadataService; @Override protected Collection> nodePlugins() { @@ -88,6 +93,7 @@ protected Collection> nodePlugins() { } public static class TestSettingPlugin extends Plugin { + @Override public List> getSettings() { return OpenSearchSettings.pluginSettings(); @@ -148,6 +154,13 @@ public void setup() { stateStore = new StateStore(client, clusterService); createIndexWithMappings(dm.getResultIndex(), loadResultIndexMappings()); createIndexWithMappings(otherDm.getResultIndex(), loadResultIndexMappings()); + flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); + } + + protected FlintIndexOpFactory getFlintIndexOpFactory( + EMRServerlessClientFactory emrServerlessClientFactory) { + return new FlintIndexOpFactory( + stateStore, client, flintIndexMetadataService, emrServerlessClientFactory); } @After @@ -205,10 +218,14 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new QueryHandlerFactory( jobExecutionResponseReader, new FlintIndexMetadataServiceImpl(client), - client, new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), new DefaultLeaseManager(pluginSettings, stateStore), - stateStore, + new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore), + new FlintIndexOpFactory( + stateStore, + client, + new FlintIndexMetadataServiceImpl(client), + emrServerlessClientFactory), emrServerlessClientFactory); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( @@ -269,6 +286,17 @@ public void setJobState(JobRunState jobState) { } } + protected LocalEMRSClient getCancelledLocalEmrsClient() { + return new LocalEMRSClient() { + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + super.getJobRunResult(applicationId, jobId); + JobRun jobRun = new JobRun(); + jobRun.setState("cancelled"); + return new GetJobRunResult().withJobRun(jobRun); + } + }; + } + public static class LocalEMRServerlessClientFactory implements EMRServerlessClientFactory { @Override @@ -333,6 +361,7 @@ public String loadResultIndexMappings() { @RequiredArgsConstructor public class FlintDatasetMock { + final String query; final String refreshQuery; final FlintIndexType indexType; diff --git a/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index 80542ba2e0..6bcf9c6308 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -21,7 +21,6 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceSpec; import org.opensearch.sql.spark.asyncquery.model.MockFlintIndex; import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob; -import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; @@ -34,70 +33,40 @@ public class FlintStreamingJobHouseKeeperTaskTest extends AsyncQueryExecutorServ @Test @SneakyThrows public void testStreamingJobHouseKeeperWhenDataSourceDisabled() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); changeDataSourceStatus(MYGLUE_DATASOURCE, DISABLED); - LocalEMRSClient emrsClient = - new LocalEMRSClient() { - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - super.getJobRunResult(applicationId, jobId); - JobRun jobRun = new JobRun(); - jobRun.setState("cancelled"); - return new GetJobRunResult().withJobRun(jobRun); - } - }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + LocalEMRSClient emrsClient = getCancelledLocalEmrsClient(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.ACTIVE); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("false", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.ACTIVE); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("false", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(3); emrsClient.getJobRunResultCalled(3); emrsClient.startJobRunCalled(0); @@ -108,64 +77,74 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { .getValue()); } + private ImmutableList getMockFlintIndices() { + return ImmutableList.of(getSkipping(), getCovering(), getMv()); + } + + private MockFlintIndex getMv() { + return new MockFlintIndex( + client, + "flint_my_glue_mydb_mv", + FlintIndexType.MATERIALIZED_VIEW, + "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," + + " incremental_refresh=true, output_mode=\"complete\") "); + } + + private MockFlintIndex getCovering() { + return new MockFlintIndex( + client, + "flint_my_glue_mydb_http_logs_covering_index", + FlintIndexType.COVERING, + "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," + + " incremental_refresh=true, output_mode=\"complete\")"); + } + + private MockFlintIndex getSkipping() { + return new MockFlintIndex( + client, + "flint_my_glue_mydb_http_logs_skipping_index", + FlintIndexType.SKIPPING, + "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," + + " incremental_refresh=true, output_mode=\"complete\")"); + } + @Test @SneakyThrows public void testStreamingJobHouseKeeperWhenCancelJobGivesTimeout() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); changeDataSourceStatus(MYGLUE_DATASOURCE, DISABLED); LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.REFRESHING); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("false", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.REFRESHING); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("false", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(3); emrsClient.getJobRunResultCalled(9); emrsClient.startJobRunCalled(0); @@ -179,62 +158,41 @@ public void testStreamingJobHouseKeeperWhenCancelJobGivesTimeout() { @Test @SneakyThrows public void testSimulateConcurrentJobHouseKeeperExecution() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); changeDataSourceStatus(MYGLUE_DATASOURCE, DISABLED); LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); FlintStreamingJobHouseKeeperTask.isRunning.compareAndSet(false, true); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.REFRESHING); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("true", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.REFRESHING); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("true", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(0); emrsClient.getJobRunResultCalled(0); emrsClient.startJobRunCalled(0); @@ -249,70 +207,40 @@ public void testSimulateConcurrentJobHouseKeeperExecution() { @SneakyThrows @Test public void testStreamingJobClearnerWhenDataSourceIsDeleted() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); this.dataSourceService.deleteDataSource(MYGLUE_DATASOURCE); - LocalEMRSClient emrsClient = - new LocalEMRSClient() { - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - super.getJobRunResult(applicationId, jobId); - JobRun jobRun = new JobRun(); - jobRun.setState("cancelled"); - return new GetJobRunResult().withJobRun(jobRun); - } - }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + LocalEMRSClient emrsClient = getCancelledLocalEmrsClient(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.DELETED); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("true", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.DELETED); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("true", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(3); emrsClient.getJobRunResultCalled(3); emrsClient.startJobRunCalled(0); @@ -326,69 +254,39 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { @Test @SneakyThrows public void testStreamingJobHouseKeeperWhenDataSourceIsNeitherDisabledNorDeleted() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); - LocalEMRSClient emrsClient = - new LocalEMRSClient() { - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - super.getJobRunResult(applicationId, jobId); - JobRun jobRun = new JobRun(); - jobRun.setState("cancelled"); - return new GetJobRunResult().withJobRun(jobRun); - } - }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); + LocalEMRSClient emrsClient = getCancelledLocalEmrsClient(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.REFRESHING); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("true", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.REFRESHING); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("true", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(0); emrsClient.getJobRunResultCalled(0); emrsClient.startJobRunCalled(0); @@ -413,14 +311,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); + emrsClient.getJobRunResultCalled(0); emrsClient.startJobRunCalled(0); emrsClient.cancelJobRunCalled(0); @@ -438,24 +337,16 @@ public void testStreamingJobHouseKeeperWhenFlintIndexIsCorrupted() throws Interr new MockFlintIndex(client(), indexName, FlintIndexType.COVERING, null); mockFlintIndex.createIndex(); changeDataSourceStatus(MYGLUE_DATASOURCE, DISABLED); - LocalEMRSClient emrsClient = - new LocalEMRSClient() { - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - super.getJobRunResult(applicationId, jobId); - JobRun jobRun = new JobRun(); - jobRun.setState("cancelled"); - return new GetJobRunResult().withJobRun(jobRun); - } - }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + LocalEMRSClient emrsClient = getCancelledLocalEmrsClient(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); + emrsClient.getJobRunResultCalled(0); emrsClient.startJobRunCalled(0); emrsClient.cancelJobRunCalled(0); @@ -479,7 +370,6 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataService() { @Override @@ -493,10 +383,12 @@ public void updateIndexToManualRefresh( }; FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); + Assertions.assertFalse(FlintStreamingJobHouseKeeperTask.isRunning.get()); emrsClient.getJobRunResultCalled(0); emrsClient.startJobRunCalled(0); @@ -511,70 +403,40 @@ public void updateIndexToManualRefresh( @Test @SneakyThrows public void testStreamingJobHouseKeeperMultipleTimesWhenDataSourceDisabled() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); changeDataSourceStatus(MYGLUE_DATASOURCE, DISABLED); - LocalEMRSClient emrsClient = - new LocalEMRSClient() { - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - super.getJobRunResult(applicationId, jobId); - JobRun jobRun = new JobRun(); - jobRun.setState("cancelled"); - return new GetJobRunResult().withJobRun(jobRun); - } - }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + LocalEMRSClient emrsClient = getCancelledLocalEmrsClient(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.ACTIVE); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("false", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.ACTIVE); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("false", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(3); emrsClient.getJobRunResultCalled(3); emrsClient.startJobRunCalled(0); @@ -588,16 +450,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { Thread thread2 = new Thread(flintStreamingJobHouseKeeperTask); thread2.start(); thread2.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.ACTIVE); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("false", options.get("auto_refresh")); - }); + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.ACTIVE); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("false", options.get("auto_refresh")); + }); // No New Calls and Errors emrsClient.cancelJobRunCalled(3); @@ -613,70 +474,40 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { @SneakyThrows @Test public void testRunStreamingJobHouseKeeperWhenDataSourceIsDeleted() { - MockFlintIndex SKIPPING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_skipping_index", - FlintIndexType.SKIPPING, - "ALTER SKIPPING INDEX ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex COVERING = - new MockFlintIndex( - client, - "flint_my_glue_mydb_http_logs_covering_index", - FlintIndexType.COVERING, - "ALTER INDEX covering ON my_glue.mydb.http_logs WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\")"); - MockFlintIndex MV = - new MockFlintIndex( - client, - "flint_my_glue_mydb_mv", - FlintIndexType.MATERIALIZED_VIEW, - "ALTER MATERIALIZED VIEW my_glue.mydb.mv WITH (auto_refresh=false," - + " incremental_refresh=true, output_mode=\"complete\") "); + ImmutableList mockFlintIndices = getMockFlintIndices(); Map indexJobMapping = new HashMap<>(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - INDEX.createIndex(); - MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); - indexJobMapping.put(INDEX, flintIndexJob); - HashMap existingOptions = new HashMap<>(); - existingOptions.put("auto_refresh", "true"); - // Making Index Auto Refresh - INDEX.updateIndexOptions(existingOptions, false); - flintIndexJob.refreshing(); - }); + mockFlintIndices.forEach( + INDEX -> { + INDEX.createIndex(); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + indexJobMapping.put(INDEX, flintIndexJob); + HashMap existingOptions = new HashMap<>(); + existingOptions.put("auto_refresh", "true"); + // Making Index Auto Refresh + INDEX.updateIndexOptions(existingOptions, false); + flintIndexJob.refreshing(); + }); this.dataSourceService.deleteDataSource(MYGLUE_DATASOURCE); - LocalEMRSClient emrsClient = - new LocalEMRSClient() { - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - super.getJobRunResult(applicationId, jobId); - JobRun jobRun = new JobRun(); - jobRun.setState("cancelled"); - return new GetJobRunResult().withJobRun(jobRun); - } - }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + LocalEMRSClient emrsClient = getCancelledLocalEmrsClient(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.DELETED); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("true", options.get("auto_refresh")); - }); + + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.DELETED); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("true", options.get("auto_refresh")); + }); emrsClient.cancelJobRunCalled(3); emrsClient.getJobRunResultCalled(3); emrsClient.startJobRunCalled(0); @@ -690,16 +521,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { Thread thread2 = new Thread(flintStreamingJobHouseKeeperTask); thread2.start(); thread2.join(); - ImmutableList.of(SKIPPING, COVERING, MV) - .forEach( - INDEX -> { - MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); - flintIndexJob.assertState(FlintIndexState.DELETED); - Map mappings = INDEX.getIndexMappings(); - Map meta = (HashMap) mappings.get("_meta"); - Map options = (Map) meta.get("options"); - Assertions.assertEquals("true", options.get("auto_refresh")); - }); + mockFlintIndices.forEach( + INDEX -> { + MockFlintSparkJob flintIndexJob = indexJobMapping.get(INDEX); + flintIndexJob.assertState(FlintIndexState.DELETED); + Map mappings = INDEX.getIndexMappings(); + Map meta = (HashMap) mappings.get("_meta"); + Map options = (Map) meta.get("options"); + Assertions.assertEquals("true", options.get("auto_refresh")); + }); // No New Calls and Errors emrsClient.cancelJobRunCalled(3); emrsClient.getJobRunResultCalled(3); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 045de66d0a..aade6ff63b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -24,35 +24,32 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.client.Client; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @ExtendWith(MockitoExtension.class) class IndexDMLHandlerTest { - @Mock private EMRServerlessClient emrServerlessClient; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private FlintIndexMetadataService flintIndexMetadataService; - @Mock private StateStore stateStore; - @Mock private Client client; + @Mock private IndexDMLResultStorageService indexDMLResultStorageService; + @Mock private FlintIndexOpFactory flintIndexOpFactory; @Test public void getResponseFromExecutor() { - JSONObject result = - new IndexDMLHandler(null, null, null, null, null).getResponseFromExecutor(null); + JSONObject result = new IndexDMLHandler(null, null, null, null).getResponseFromExecutor(null); assertEquals("running", result.getString(STATUS_FIELD)); assertEquals("", result.getString(ERROR_FIELD)); @@ -62,11 +59,10 @@ public void getResponseFromExecutor() { public void testWhenIndexDetailsAreNotFound() { IndexDMLHandler indexDMLHandler = new IndexDMLHandler( - emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, - client); + indexDMLResultStorageService, + flintIndexOpFactory); DispatchQueryRequest dispatchQueryRequest = new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -94,8 +90,10 @@ public void testWhenIndexDetailsAreNotFound() { .build(); Mockito.when(flintIndexMetadataService.getFlintIndexMetadata(any())) .thenReturn(new HashMap<>()); + DispatchQueryResponse dispatchQueryResponse = indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext); + Assertions.assertNotNull(dispatchQueryResponse.getQueryId()); } @@ -104,11 +102,10 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { FlintIndexMetadata flintIndexMetadata = mock(FlintIndexMetadata.class); IndexDMLHandler indexDMLHandler = new IndexDMLHandler( - emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, - client); + indexDMLResultStorageService, + flintIndexOpFactory); DispatchQueryRequest dispatchQueryRequest = new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -139,6 +136,7 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { flintMetadataMap.put(indexQueryDetails.openSearchIndexName(), flintIndexMetadata); when(flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName())) .thenReturn(flintMetadataMap); + indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 8de5fe3fb4..36264e49c6 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -54,7 +54,6 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.client.Client; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; @@ -72,8 +71,9 @@ import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -86,13 +86,10 @@ public class SparkQueryDispatcherTest { @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private FlintIndexMetadataService flintIndexMetadataService; - - @Mock(answer = RETURNS_DEEP_STUBS) - private Client openSearchClient; - @Mock private SessionManager sessionManager; - @Mock private LeaseManager leaseManager; + @Mock private IndexDMLResultStorageService indexDMLResultStorageService; + @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -100,8 +97,6 @@ public class SparkQueryDispatcherTest { @Mock(answer = RETURNS_DEEP_STUBS) private Statement statement; - @Mock private StateStore stateStore; - private SparkQueryDispatcher sparkQueryDispatcher; private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); @@ -114,13 +109,14 @@ void setUp() { new QueryHandlerFactory( jobExecutionResponseReader, flintIndexMetadataService, - openSearchClient, sessionManager, leaseManager, - stateStore, + indexDMLResultStorageService, + flintIndexOpFactory, emrServerlessClientFactory); sparkQueryDispatcher = new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); + new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); } @Test @@ -405,7 +401,6 @@ void testDispatchIndexQuery() { tags, true, "query_execution_result_my_glue"); - when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) @@ -420,6 +415,7 @@ void testDispatchIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -513,6 +509,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -563,6 +560,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -613,6 +611,7 @@ void testDispatchMaterializedViewQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -659,6 +658,7 @@ void testDispatchShowMVQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -751,6 +751,7 @@ void testDispatchDescribeIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -962,7 +963,9 @@ void testGetQueryResponseWithSuccess() { queryResult.put(DATA_FIELD, resultMap); when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(queryResult); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); + verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null); Assertions.assertEquals( new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java index 5755d03baa..b3dc65a5fe 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java @@ -13,6 +13,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -22,6 +23,7 @@ public class FlintIndexOpTest { @Mock private StateStore mockStateStore; + @Mock private EMRServerlessClientFactory mockEmrServerlessClientFactory; @Test public void testApplyWithTransitioningStateFailure() { @@ -42,7 +44,8 @@ public void testApplyWithTransitioningStateFailure() { .thenReturn(Optional.of(fakeModel)); when(mockStateStore.updateState(any(), any(), any(), any())) .thenThrow(new RuntimeException("Transitioning state failed")); - FlintIndexOp flintIndexOp = new TestFlintIndexOp(mockStateStore, "myS3"); + FlintIndexOp flintIndexOp = + new TestFlintIndexOp(mockStateStore, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -70,7 +73,8 @@ public void testApplyWithCommitFailure() { .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2)) .thenThrow(new RuntimeException("Commit state failed")) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 3)); - FlintIndexOp flintIndexOp = new TestFlintIndexOp(mockStateStore, "myS3"); + FlintIndexOp flintIndexOp = + new TestFlintIndexOp(mockStateStore, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -98,7 +102,8 @@ public void testApplyWithRollBackFailure() { .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2)) .thenThrow(new RuntimeException("Commit state failed")) .thenThrow(new RuntimeException("Rollback failure")); - FlintIndexOp flintIndexOp = new TestFlintIndexOp(mockStateStore, "myS3"); + FlintIndexOp flintIndexOp = + new TestFlintIndexOp(mockStateStore, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -107,8 +112,11 @@ public void testApplyWithRollBackFailure() { static class TestFlintIndexOp extends FlintIndexOp { - public TestFlintIndexOp(StateStore stateStore, String datasourceName) { - super(stateStore, datasourceName); + public TestFlintIndexOp( + StateStore stateStore, + String datasourceName, + EMRServerlessClientFactory emrServerlessClientFactory) { + super(stateStore, datasourceName, emrServerlessClientFactory); } @Override From 3c894b750f915ac82259e9177acdec6907722f00 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 9 May 2024 15:30:15 -0700 Subject: [PATCH 2/4] Introduce FlintIndexStateModelService (#2658) * Introduce FlintIndexStateModelService Signed-off-by: Tomoyuki Morita * Reformat Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita (cherry picked from commit df1c04a3b83e53460848864aace352186e0dc537) --- .../spark/dispatcher/BatchQueryHandler.java | 4 +- .../dispatcher/StreamingQueryHandler.java | 2 - .../statestore/OpenSearchStateStoreUtil.java | 20 +++++ .../flint/FlintIndexStateModelService.java | 26 +++++++ ...OpenSearchFlintIndexStateModelService.java | 50 ++++++++++++ .../spark/flint/operation/FlintIndexOp.java | 23 +++--- .../flint/operation/FlintIndexOpAlter.java | 6 +- .../flint/operation/FlintIndexOpCancel.java | 6 +- .../flint/operation/FlintIndexOpDrop.java | 7 +- .../flint/operation/FlintIndexOpFactory.java | 15 ++-- .../flint/operation/FlintIndexOpVacuum.java | 6 +- .../config/AsyncExecutorServiceModule.java | 11 ++- .../AsyncQueryExecutorServiceSpec.java | 8 +- .../AsyncQueryGetResultSpecTest.java | 3 +- .../asyncquery/IndexQuerySpecAlterTest.java | 48 ++++++++---- .../spark/asyncquery/IndexQuerySpecTest.java | 43 +++++++---- .../asyncquery/IndexQuerySpecVacuumTest.java | 3 +- .../asyncquery/model/MockFlintSparkJob.java | 41 ++++------ .../FlintStreamingJobHouseKeeperTaskTest.java | 21 +++-- .../dispatcher/SparkQueryDispatcherTest.java | 2 + .../OpenSearchStateStoreUtilTest.java | 20 +++++ ...SearchFlintIndexStateModelServiceTest.java | 77 +++++++++++++++++++ .../flint/operation/FlintIndexOpTest.java | 32 ++++---- 23 files changed, 356 insertions(+), 118 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtil.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtilTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index e9356e5bed..c5cbc1e539 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -30,8 +30,8 @@ @RequiredArgsConstructor public class BatchQueryHandler extends AsyncQueryHandler { - private final EMRServerlessClient emrServerlessClient; - private final JobExecutionResponseReader jobExecutionResponseReader; + protected final EMRServerlessClient emrServerlessClient; + protected final JobExecutionResponseReader jobExecutionResponseReader; protected final LeaseManager leaseManager; @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 8170b41c66..08c10e04cc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -28,14 +28,12 @@ /** Handle Streaming Query. */ public class StreamingQueryHandler extends BatchQueryHandler { - private final EMRServerlessClient emrServerlessClient; public StreamingQueryHandler( EMRServerlessClient emrServerlessClient, JobExecutionResponseReader jobExecutionResponseReader, LeaseManager leaseManager) { super(emrServerlessClient, jobExecutionResponseReader, leaseManager); - this.emrServerlessClient = emrServerlessClient; } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtil.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtil.java new file mode 100644 index 0000000000..da9d166fcf --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtil.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME; + +import java.util.Locale; +import lombok.experimental.UtilityClass; + +@UtilityClass +public class OpenSearchStateStoreUtil { + + public static String getIndexName(String datasourceName) { + return String.format( + "%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName.toLowerCase(Locale.ROOT)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java new file mode 100644 index 0000000000..a00056fd53 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import java.util.Optional; + +/** + * Abstraction over flint index state storage. Flint index state will maintain the status of each + * flint index. + */ +public interface FlintIndexStateModelService { + FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, String datasourceName); + + Optional getFlintIndexStateModel(String id, String datasourceName); + + FlintIndexStateModel updateFlintIndexState( + FlintIndexStateModel flintIndexStateModel, + FlintIndexState flintIndexState, + String datasourceName); + + boolean deleteFlintIndexStateModel(String id, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java new file mode 100644 index 0000000000..2db3930821 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; +import org.opensearch.sql.spark.execution.statestore.StateStore; + +@RequiredArgsConstructor +public class OpenSearchFlintIndexStateModelService implements FlintIndexStateModelService { + private final StateStore stateStore; + + @Override + public FlintIndexStateModel updateFlintIndexState( + FlintIndexStateModel flintIndexStateModel, + FlintIndexState flintIndexState, + String datasourceName) { + return stateStore.updateState( + flintIndexStateModel, + flintIndexState, + FlintIndexStateModel::copyWithState, + OpenSearchStateStoreUtil.getIndexName(datasourceName)); + } + + @Override + public Optional getFlintIndexStateModel(String id, String datasourceName) { + return stateStore.get( + id, + FlintIndexStateModel::fromXContent, + OpenSearchStateStoreUtil.getIndexName(datasourceName)); + } + + @Override + public FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, String datasourceName) { + return stateStore.create( + flintIndexStateModel, + FlintIndexStateModel::copy, + OpenSearchStateStoreUtil.getIndexName(datasourceName)); + } + + @Override + public boolean deleteFlintIndexStateModel(String id, String datasourceName) { + return stateStore.delete(id, OpenSearchStateStoreUtil.getIndexName(datasourceName)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index edfd0aace2..0b1ccc988e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -6,9 +6,6 @@ package org.opensearch.sql.spark.flint.operation; import static org.opensearch.sql.spark.client.EmrServerlessClientImpl.GENERIC_INTERNAL_SERVER_ERROR_MESSAGE; -import static org.opensearch.sql.spark.execution.statestore.StateStore.deleteFlintIndexState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getFlintIndexState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateFlintIndexState; import com.amazonaws.services.emrserverless.model.ValidationException; import java.util.Locale; @@ -22,17 +19,17 @@ import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** Flint Index Operation. */ @RequiredArgsConstructor public abstract class FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); - private final StateStore stateStore; + private final FlintIndexStateModelService flintIndexStateModelService; private final String datasourceName; private final EMRServerlessClientFactory emrServerlessClientFactory; @@ -57,8 +54,10 @@ public void apply(FlintIndexMetadata metadata) { } catch (Throwable e) { LOG.error("Rolling back transient log due to transaction operation failure", e); try { - updateFlintIndexState(stateStore, datasourceName) - .apply(transitionedFlintIndexStateModel, initialFlintIndexStateModel.getIndexState()); + flintIndexStateModelService.updateFlintIndexState( + transitionedFlintIndexStateModel, + initialFlintIndexStateModel.getIndexState(), + datasourceName); } catch (Exception ex) { LOG.error("Failed to rollback transient log", ex); } @@ -70,7 +69,7 @@ public void apply(FlintIndexMetadata metadata) { @NotNull private FlintIndexStateModel getFlintIndexStateModel(String latestId) { Optional flintIndexOptional = - getFlintIndexState(stateStore, datasourceName).apply(latestId); + flintIndexStateModelService.getFlintIndexStateModel(latestId, datasourceName); if (flintIndexOptional.isEmpty()) { String errorMsg = String.format(Locale.ROOT, "no state found. docId: %s", latestId); LOG.error(errorMsg); @@ -111,7 +110,8 @@ private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flint FlintIndexState transitioningState = transitioningState(); try { flintIndex = - updateFlintIndexState(stateStore, datasourceName).apply(flintIndex, transitioningState()); + flintIndexStateModelService.updateFlintIndexState( + flintIndex, transitioningState(), datasourceName); } catch (Exception e) { String errorMsg = String.format(Locale.ROOT, "Moving to transition state:%s failed.", transitioningState); @@ -127,9 +127,10 @@ private void commit(FlintIndexStateModel flintIndex) { try { if (stableState == FlintIndexState.NONE) { LOG.info("Deleting index state with docId: " + flintIndex.getLatestId()); - deleteFlintIndexState(stateStore, datasourceName).apply(flintIndex.getLatestId()); + flintIndexStateModelService.deleteFlintIndexStateModel( + flintIndex.getLatestId(), datasourceName); } else { - updateFlintIndexState(stateStore, datasourceName).apply(flintIndex, stableState); + flintIndexStateModelService.updateFlintIndexState(flintIndex, stableState, datasourceName); } } catch (Exception e) { String errorMsg = diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 31e33539a1..9955320253 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -10,11 +10,11 @@ import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** * Index Operation for Altering the flint index. Only handles alter operation when @@ -27,11 +27,11 @@ public class FlintIndexOpAlter extends FlintIndexOp { public FlintIndexOpAlter( FlintIndexOptions flintIndexOptions, - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, String datasourceName, EMRServerlessClientFactory emrServerlessClientFactory, FlintIndexMetadataService flintIndexMetadataService) { - super(stateStore, datasourceName, emrServerlessClientFactory); + super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); this.flintIndexMetadataService = flintIndexMetadataService; this.flintIndexOptions = flintIndexOptions; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java index 0962e2a16b..02c8e39c66 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java @@ -9,20 +9,20 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** Cancel refreshing job for refresh query when user clicks cancel button on UI. */ public class FlintIndexOpCancel extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); public FlintIndexOpCancel( - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, String datasourceName, EMRServerlessClientFactory emrServerlessClientFactory) { - super(stateStore, datasourceName, emrServerlessClientFactory); + super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); } // Only in refreshing state, the job is cancellable in case of REFRESH query. diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java index 0f71b3bc70..6613c29870 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java @@ -9,19 +9,20 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +/** Operation to drop Flint index */ public class FlintIndexOpDrop extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); public FlintIndexOpDrop( - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, String datasourceName, EMRServerlessClientFactory emrServerlessClientFactory) { - super(stateStore, datasourceName, emrServerlessClientFactory); + super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); } public boolean validate(FlintIndexState state) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java index 6fc2261ade..b102e43d59 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java @@ -9,34 +9,37 @@ import org.opensearch.client.Client; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; @RequiredArgsConstructor public class FlintIndexOpFactory { - private final StateStore stateStore; + private final FlintIndexStateModelService flintIndexStateModelService; private final Client client; private final FlintIndexMetadataService flintIndexMetadataService; private final EMRServerlessClientFactory emrServerlessClientFactory; public FlintIndexOpDrop getDrop(String datasource) { - return new FlintIndexOpDrop(stateStore, datasource, emrServerlessClientFactory); + return new FlintIndexOpDrop( + flintIndexStateModelService, datasource, emrServerlessClientFactory); } public FlintIndexOpAlter getAlter(FlintIndexOptions flintIndexOptions, String datasource) { return new FlintIndexOpAlter( flintIndexOptions, - stateStore, + flintIndexStateModelService, datasource, emrServerlessClientFactory, flintIndexMetadataService); } public FlintIndexOpVacuum getVacuum(String datasource) { - return new FlintIndexOpVacuum(stateStore, datasource, client, emrServerlessClientFactory); + return new FlintIndexOpVacuum( + flintIndexStateModelService, datasource, client, emrServerlessClientFactory); } public FlintIndexOpCancel getCancel(String datasource) { - return new FlintIndexOpCancel(stateStore, datasource, emrServerlessClientFactory); + return new FlintIndexOpCancel( + flintIndexStateModelService, datasource, emrServerlessClientFactory); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java index 4287d9c7c9..ffd09e16a4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java @@ -11,10 +11,10 @@ import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Client; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** Flint index vacuum operation. */ public class FlintIndexOpVacuum extends FlintIndexOp { @@ -25,11 +25,11 @@ public class FlintIndexOpVacuum extends FlintIndexOp { private final Client client; public FlintIndexOpVacuum( - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, String datasourceName, Client client, EMRServerlessClientFactory emrServerlessClientFactory) { - super(stateStore, datasourceName, emrServerlessClientFactory); + super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); this.client = client; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 1d890ce346..dfc8e4042a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -30,7 +30,9 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.OpenSearchFlintIndexStateModelService; import org.opensearch.sql.spark.flint.OpenSearchIndexDMLResultStorageService; import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; @@ -96,12 +98,17 @@ public QueryHandlerFactory queryhandlerFactory( @Provides public FlintIndexOpFactory flintIndexOpFactory( - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, NodeClient client, FlintIndexMetadataServiceImpl flintIndexMetadataService, EMRServerlessClientFactory emrServerlessClientFactory) { return new FlintIndexOpFactory( - stateStore, client, flintIndexMetadataService, emrServerlessClientFactory); + flintIndexStateModelService, client, flintIndexMetadataService, emrServerlessClientFactory); + } + + @Provides + public FlintIndexStateModelService flintIndexStateModelService(StateStore stateStore) { + return new OpenSearchFlintIndexStateModelService(stateStore); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index b1c7f68388..84a2128821 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -66,7 +66,9 @@ import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.flint.OpenSearchFlintIndexStateModelService; import org.opensearch.sql.spark.flint.OpenSearchIndexDMLResultStorageService; import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; @@ -86,6 +88,7 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected StateStore stateStore; protected ClusterSettings clusterSettings; protected FlintIndexMetadataService flintIndexMetadataService; + protected FlintIndexStateModelService flintIndexStateModelService; @Override protected Collection> nodePlugins() { @@ -155,12 +158,13 @@ public void setup() { createIndexWithMappings(dm.getResultIndex(), loadResultIndexMappings()); createIndexWithMappings(otherDm.getResultIndex(), loadResultIndexMappings()); flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); + flintIndexStateModelService = new OpenSearchFlintIndexStateModelService(stateStore); } protected FlintIndexOpFactory getFlintIndexOpFactory( EMRServerlessClientFactory emrServerlessClientFactory) { return new FlintIndexOpFactory( - stateStore, client, flintIndexMetadataService, emrServerlessClientFactory); + flintIndexStateModelService, client, flintIndexMetadataService, emrServerlessClientFactory); } @After @@ -222,7 +226,7 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new DefaultLeaseManager(pluginSettings, stateStore), new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore), new FlintIndexOpFactory( - stateStore, + flintIndexStateModelService, client, new FlintIndexMetadataServiceImpl(client), emrServerlessClientFactory), diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 10598d110c..6dcc2c17af 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -53,7 +53,8 @@ public class AsyncQueryGetResultSpecTest extends AsyncQueryExecutorServiceSpec { @Before public void doSetUp() { - mockIndexState = new MockFlintSparkJob(stateStore, mockIndex.latestId, MYS3_DATASOURCE); + mockIndexState = + new MockFlintSparkJob(flintIndexStateModelService, mockIndex.latestId, MYS3_DATASOURCE); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java index ddefebcf77..d49e3883da 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java @@ -68,7 +68,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -135,7 +136,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, true); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -215,7 +217,8 @@ public CancelJobRunResult cancelJobRun( mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -277,7 +280,8 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -341,7 +345,8 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -414,7 +419,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -487,7 +493,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -554,7 +561,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, true); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -614,7 +622,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, true); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -676,7 +685,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, true); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -738,7 +748,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, true); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.refreshing(); // 1. alter index @@ -797,7 +808,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, true); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.refreshing(); // 1. alter index @@ -854,7 +866,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.updating(); // 1. alter index @@ -919,7 +932,8 @@ public CancelJobRunResult cancelJobRun( mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -982,7 +996,8 @@ public CancelJobRunResult cancelJobRun( mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index @@ -1046,7 +1061,8 @@ public CancelJobRunResult cancelJobRun( mockDS.updateIndexOptions(existingOptions, false); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.getLatestId(), MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.getLatestId(), MYS3_DATASOURCE); flintIndexJob.active(); // 1. alter index diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 864a87586f..09addccdbb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -294,7 +294,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // 1.drop index @@ -352,7 +353,8 @@ public CancelJobRunResult cancelJobRun( mockDS.createIndex(); // Mock index state in refresh state. MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // 1.drop index @@ -397,7 +399,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // 1. drop index @@ -441,7 +444,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // 1. drop index @@ -490,7 +494,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.active(); // 1. drop index @@ -536,7 +541,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.creating(); // 1. drop index @@ -582,7 +588,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); // 1. drop index CreateAsyncQueryResponse response = @@ -634,7 +641,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); flintIndexJob.deleting(); // 1. drop index @@ -679,7 +687,7 @@ public CancelJobRunResult cancelJobRun( mockDS.createIndex(); // Mock index state in refresh state. MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYGLUE_DATASOURCE); + new MockFlintSparkJob(flintIndexStateModelService, mockDS.latestId, MYGLUE_DATASOURCE); flintIndexJob.refreshing(); // 1.drop index @@ -752,7 +760,7 @@ public void concurrentRefreshJobLimitNotApplied() { COVERING.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, COVERING.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob(flintIndexStateModelService, COVERING.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // query with auto refresh @@ -777,7 +785,7 @@ public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { COVERING.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, COVERING.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob(flintIndexStateModelService, COVERING.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // query with auto_refresh = true. @@ -805,7 +813,7 @@ public void concurrentRefreshJobLimitAppliedToRefresh() { COVERING.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, COVERING.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob(flintIndexStateModelService, COVERING.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); // query with auto_refresh = true. @@ -832,7 +840,7 @@ public void concurrentRefreshJobLimitNotAppliedToDDL() { COVERING.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, COVERING.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob(flintIndexStateModelService, COVERING.latestId, MYS3_DATASOURCE); flintIndexJob.refreshing(); CreateAsyncQueryResponse asyncQueryResponse = @@ -905,7 +913,8 @@ public GetJobRunResult getJobRunResult( mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); // 1. Submit REFRESH statement CreateAsyncQueryResponse response = @@ -948,7 +957,8 @@ public GetJobRunResult getJobRunResult( mockDS.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, mockDS.latestId, MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, mockDS.latestId, MYS3_DATASOURCE); // 1. Submit REFRESH statement CreateAsyncQueryResponse response = @@ -990,7 +1000,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockFlintIndex.createIndex(); // Mock index state MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, indexName + "_latest_id", MYS3_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, indexName + "_latest_id", MYS3_DATASOURCE); // 1. Submit REFRESH statement CreateAsyncQueryResponse response = diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java index 76adddf89d..c9660c8d87 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java @@ -164,7 +164,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { mockDS.createIndex(); // Mock index state doc - MockFlintSparkJob flintIndexJob = new MockFlintSparkJob(stateStore, mockDS.latestId, "mys3"); + MockFlintSparkJob flintIndexJob = + new MockFlintSparkJob(flintIndexStateModelService, mockDS.latestId, "mys3"); flintIndexJob.transition(state); // Vacuum index diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java index 4cfdb6a9a9..4c58ea472f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java @@ -11,18 +11,19 @@ import java.util.Optional; import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; public class MockFlintSparkJob { private FlintIndexStateModel stateModel; - private StateStore stateStore; + private FlintIndexStateModelService flintIndexStateModelService; private String datasource; - public MockFlintSparkJob(StateStore stateStore, String latestId, String datasource) { + public MockFlintSparkJob( + FlintIndexStateModelService flintIndexStateModelService, String latestId, String datasource) { assertNotNull(latestId); - this.stateStore = stateStore; + this.flintIndexStateModelService = flintIndexStateModelService; this.datasource = datasource; stateModel = new FlintIndexStateModel( @@ -35,54 +36,42 @@ public MockFlintSparkJob(StateStore stateStore, String latestId, String datasour "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - stateModel = StateStore.createFlintIndexState(stateStore, datasource).apply(stateModel); + stateModel = flintIndexStateModelService.createFlintIndexStateModel(stateModel, datasource); } public void transition(FlintIndexState newState) { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource).apply(stateModel, newState); + flintIndexStateModelService.updateFlintIndexState(stateModel, newState, datasource); } public void refreshing() { - stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.REFRESHING); + transition(FlintIndexState.REFRESHING); } public void active() { - stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.ACTIVE); + transition(FlintIndexState.ACTIVE); } public void creating() { - stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.CREATING); + transition(FlintIndexState.CREATING); } public void updating() { - stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.UPDATING); + transition(FlintIndexState.UPDATING); } public void deleting() { - stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.DELETING); + transition(FlintIndexState.DELETING); } public void deleted() { - stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.DELETED); + transition(FlintIndexState.DELETED); } public void assertState(FlintIndexState expected) { Optional stateModelOpt = - StateStore.getFlintIndexState(stateStore, datasource).apply(stateModel.getId()); - assertTrue((stateModelOpt.isPresent())); + flintIndexStateModelService.getFlintIndexStateModel(stateModel.getId(), datasource); + assertTrue(stateModelOpt.isPresent()); assertEquals(expected, stateModelOpt.get().getIndexState()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index 6bcf9c6308..aa4684811f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -39,7 +39,8 @@ public void testStreamingJobHouseKeeperWhenDataSourceDisabled() { INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); @@ -117,7 +118,8 @@ public void testStreamingJobHouseKeeperWhenCancelJobGivesTimeout() { INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); @@ -164,7 +166,8 @@ public void testSimulateConcurrentJobHouseKeeperExecution() { INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); @@ -213,7 +216,8 @@ public void testStreamingJobClearnerWhenDataSourceIsDeleted() { INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); @@ -260,7 +264,8 @@ public void testStreamingJobHouseKeeperWhenDataSourceIsNeitherDisabledNorDeleted INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); @@ -409,7 +414,8 @@ public void testStreamingJobHouseKeeperMultipleTimesWhenDataSourceDisabled() { INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); @@ -480,7 +486,8 @@ public void testRunStreamingJobHouseKeeperWhenDataSourceIsDeleted() { INDEX -> { INDEX.createIndex(); MockFlintSparkJob flintIndexJob = - new MockFlintSparkJob(stateStore, INDEX.getLatestId(), MYGLUE_DATASOURCE); + new MockFlintSparkJob( + flintIndexStateModelService, INDEX.getLatestId(), MYGLUE_DATASOURCE); indexJobMapping.put(INDEX, flintIndexJob); HashMap existingOptions = new HashMap<>(); existingOptions.put("auto_refresh", "true"); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 36264e49c6..92fd6b3d0a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -301,6 +301,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -705,6 +706,7 @@ void testRefreshIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtilTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtilTest.java new file mode 100644 index 0000000000..318080ff2d --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStateStoreUtilTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.junit.Assert.assertEquals; + +import org.junit.jupiter.api.Test; + +public class OpenSearchStateStoreUtilTest { + + @Test + void getIndexName() { + String result = OpenSearchStateStoreUtil.getIndexName("DATASOURCE"); + + assertEquals(".query_execution_request_datasource", result); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java new file mode 100644 index 0000000000..aebc136b93 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.execution.statestore.StateStore; + +@ExtendWith(MockitoExtension.class) +public class OpenSearchFlintIndexStateModelServiceTest { + + public static final String DATASOURCE = "DATASOURCE"; + public static final String ID = "ID"; + + @Mock StateStore mockStateStore; + @Mock FlintIndexStateModel flintIndexStateModel; + @Mock FlintIndexState flintIndexState; + @Mock FlintIndexStateModel responseFlintIndexStateModel; + + @InjectMocks OpenSearchFlintIndexStateModelService openSearchFlintIndexStateModelService; + + @Test + void updateFlintIndexState() { + when(mockStateStore.updateState(any(), any(), any(), any())) + .thenReturn(responseFlintIndexStateModel); + + FlintIndexStateModel result = + openSearchFlintIndexStateModelService.updateFlintIndexState( + flintIndexStateModel, flintIndexState, DATASOURCE); + + assertEquals(responseFlintIndexStateModel, result); + } + + @Test + void getFlintIndexStateModel() { + when(mockStateStore.get(any(), any(), any())) + .thenReturn(Optional.of(responseFlintIndexStateModel)); + + Optional result = + openSearchFlintIndexStateModelService.getFlintIndexStateModel("ID", DATASOURCE); + + assertEquals(responseFlintIndexStateModel, result.get()); + } + + @Test + void createFlintIndexStateModel() { + when(mockStateStore.create(any(), any(), any())).thenReturn(responseFlintIndexStateModel); + + FlintIndexStateModel result = + openSearchFlintIndexStateModelService.createFlintIndexStateModel( + flintIndexStateModel, DATASOURCE); + + assertEquals(responseFlintIndexStateModel, result); + } + + @Test + void deleteFlintIndexStateModel() { + when(mockStateStore.delete(any(), any())).thenReturn(true); + + boolean result = + openSearchFlintIndexStateModelService.deleteFlintIndexStateModel(ID, DATASOURCE); + + assertTrue(result); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java index b3dc65a5fe..6c2a3a81a4 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java @@ -1,10 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.spark.flint.operation; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.util.Optional; import org.junit.jupiter.api.Assertions; @@ -14,15 +18,15 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; @ExtendWith(MockitoExtension.class) public class FlintIndexOpTest { - @Mock private StateStore mockStateStore; + @Mock private FlintIndexStateModelService flintIndexStateModelService; @Mock private EMRServerlessClientFactory mockEmrServerlessClientFactory; @Test @@ -40,12 +44,12 @@ public void testApplyWithTransitioningStateFailure() { "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - when(mockStateStore.get(eq("latestId"), any(), eq(DATASOURCE_TO_REQUEST_INDEX.apply("myS3")))) + when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) .thenReturn(Optional.of(fakeModel)); - when(mockStateStore.updateState(any(), any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) .thenThrow(new RuntimeException("Transitioning state failed")); FlintIndexOp flintIndexOp = - new TestFlintIndexOp(mockStateStore, "myS3", mockEmrServerlessClientFactory); + new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -67,14 +71,14 @@ public void testApplyWithCommitFailure() { "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - when(mockStateStore.get(eq("latestId"), any(), eq(DATASOURCE_TO_REQUEST_INDEX.apply("myS3")))) + when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) .thenReturn(Optional.of(fakeModel)); - when(mockStateStore.updateState(any(), any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2)) .thenThrow(new RuntimeException("Commit state failed")) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 3)); FlintIndexOp flintIndexOp = - new TestFlintIndexOp(mockStateStore, "myS3", mockEmrServerlessClientFactory); + new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -96,14 +100,14 @@ public void testApplyWithRollBackFailure() { "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - when(mockStateStore.get(eq("latestId"), any(), eq(DATASOURCE_TO_REQUEST_INDEX.apply("myS3")))) + when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) .thenReturn(Optional.of(fakeModel)); - when(mockStateStore.updateState(any(), any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2)) .thenThrow(new RuntimeException("Commit state failed")) .thenThrow(new RuntimeException("Rollback failure")); FlintIndexOp flintIndexOp = - new TestFlintIndexOp(mockStateStore, "myS3", mockEmrServerlessClientFactory); + new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -113,10 +117,10 @@ public void testApplyWithRollBackFailure() { static class TestFlintIndexOp extends FlintIndexOp { public TestFlintIndexOp( - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, String datasourceName, EMRServerlessClientFactory emrServerlessClientFactory) { - super(stateStore, datasourceName, emrServerlessClientFactory); + super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); } @Override From ddcf96f4190bbc08be243cabd4bf54959d2fc819 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Fri, 10 May 2024 12:54:16 -0700 Subject: [PATCH 3/4] Add comments to async query handlers (#2657) * Add comments to query handlers Signed-off-by: Tomoyuki Morita * Reformat Signed-off-by: Tomoyuki Morita * Fix comments Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita (cherry picked from commit 05a2f66a0af2baab747aa0afd437441dbd5efc67) --- .../opensearch/sql/spark/dispatcher/BatchQueryHandler.java | 4 ++++ .../opensearch/sql/spark/dispatcher/IndexDMLHandler.java | 7 ++++++- .../sql/spark/dispatcher/InteractiveQueryHandler.java | 6 ++++++ .../sql/spark/dispatcher/RefreshQueryHandler.java | 5 ++++- .../sql/spark/dispatcher/StreamingQueryHandler.java | 5 ++++- .../sql/spark/flint/IndexDMLResultStorageService.java | 3 +++ 6 files changed, 27 insertions(+), 3 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index c5cbc1e539..d06153bf79 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -28,6 +28,10 @@ import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.response.JobExecutionResponseReader; +/** + * The handler for batch query. With batch query, queries are executed as single batch. The queries + * are sent along with job execution request ({@link StartJobRequest}) to spark. + */ @RequiredArgsConstructor public class BatchQueryHandler extends AsyncQueryHandler { protected final EMRServerlessClient emrServerlessClient; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index dfd5316f6c..b2bb590c1e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -31,7 +31,12 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.response.JobExecutionResponseReader; -/** Handle Index DML query. includes * DROP * ALT? */ +/** + * The handler for Index DML (Data Manipulation Language) query. Handles DROP/ALTER/VACUUM operation + * for flint indices. It will stop streaming query job as needed (e.g. when the flint index is + * automatically updated by a streaming query, the streaming query is stopped when the index is + * dropped) + */ @RequiredArgsConstructor public class IndexDMLHandler extends AsyncQueryHandler { private static final Logger LOG = LogManager.getLogger(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 7602988d26..7475c5a7ae 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -35,6 +35,12 @@ import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.response.JobExecutionResponseReader; +/** + * The handler for interactive query. With interactive query, a session will be first established + * and then the session will be reused for the following queries(statements). Session is an + * abstraction of spark job, and once the job is started, the job will continuously poll the + * statements and execute query specified in it. + */ @RequiredArgsConstructor public class InteractiveQueryHandler extends AsyncQueryHandler { private final SessionManager sessionManager; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index aeb5c1b35f..edb0a3f507 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -20,7 +20,10 @@ import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; -/** Handle Refresh Query. */ +/** + * The handler for refresh query. Refresh query is one time query request to refresh(update) flint + * index, and new job is submitted to Spark. + */ public class RefreshQueryHandler extends BatchQueryHandler { private final FlintIndexMetadataService flintIndexMetadataService; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 08c10e04cc..4a9b1ce5d5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -26,7 +26,10 @@ import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.response.JobExecutionResponseReader; -/** Handle Streaming Query. */ +/** + * The handler for streaming query. Streaming query is a job to continuously update flint index. + * Once started, the job can be stopped by IndexDML query. + */ public class StreamingQueryHandler extends BatchQueryHandler { public StreamingQueryHandler( diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java index 4a046564f5..31d4be511e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java @@ -7,6 +7,9 @@ import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; +/** + * Abstraction over the IndexDMLResult storage. It stores the result of IndexDML query execution. + */ public interface IndexDMLResultStorageService { IndexDMLResult createIndexDMLResult(IndexDMLResult result, String datasourceName); } From b48d36fee82278964c434a810232fa2797a11fd8 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Wed, 15 May 2024 11:26:36 -0700 Subject: [PATCH 4/4] Extract SessionStorageService and StatementStorageService (#2665) * Extract SessionStorageService and StatementStorageService Signed-off-by: Tomoyuki Morita * Reformat Signed-off-by: Tomoyuki Morita * Add copyright comment Signed-off-by: Tomoyuki Morita * Add comments and remove unused methods Signed-off-by: Tomoyuki Morita * Remove unneeded imports Signed-off-by: Tomoyuki Morita * Fix code format issue Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita (cherry picked from commit 1985459b7979ca6d1d9cae0b2c04851e6657f5af) --- .../execution/session/InteractiveSession.java | 22 +- .../execution/session/SessionManager.java | 27 +-- .../spark/execution/statement/Statement.java | 17 +- .../OpenSearchSessionStorageService.java | 41 ++++ .../OpenSearchStatementStorageService.java | 41 ++++ .../statestore/SessionStorageService.java | 21 ++ .../execution/statestore/StateStore.java | 81 ------- .../statestore/StatementStorageService.java | 24 ++ .../config/AsyncExecutorServiceModule.java | 20 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 12 +- .../AsyncQueryExecutorServiceSpec.java | 32 ++- .../AsyncQueryGetResultSpecTest.java | 6 +- .../session/InteractiveSessionTest.java | 147 ++++-------- .../execution/session/SessionManagerTest.java | 17 +- .../execution/session/SessionTestUtil.java | 26 +++ .../session/TestEMRServerlessClient.java | 51 ++++ .../execution/statement/StatementTest.java | 221 +++++++----------- 17 files changed, 423 insertions(+), 383 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/session/TestEMRServerlessClient.java diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 2363615a7d..f08ef4f489 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -10,8 +10,6 @@ import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; import static org.opensearch.sql.spark.execution.session.SessionState.FAIL; import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import java.util.Optional; import lombok.Builder; @@ -24,7 +22,8 @@ import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.TimeProvider; @@ -41,7 +40,8 @@ public class InteractiveSession implements Session { public static final String SESSION_ID_TAG_KEY = "sid"; private final SessionId sessionId; - private final StateStore stateStore; + private final SessionStorageService sessionStorageService; + private final StatementStorageService statementStorageService; private final EMRServerlessClient serverlessClient; private SessionModel sessionModel; // the threshold of elapsed time in milliseconds before we say a session is stale @@ -64,7 +64,7 @@ public void open(CreateSessionRequest createSessionRequest) { sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel); + sessionStorageService.createSession(sessionModel, sessionModel.getDatasourceName()); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -76,7 +76,7 @@ public void open(CreateSessionRequest createSessionRequest) { @Override public void close() { Optional model = - getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); + sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -88,7 +88,7 @@ public void close() { /** Submit statement. If submit successfully, Statement in waiting state. */ public StatementId submit(QueryRequest request) { Optional model = - getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); + sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -101,7 +101,7 @@ public StatementId submit(QueryRequest request) { .sessionId(sessionId) .applicationId(sessionModel.getApplicationId()) .jobId(sessionModel.getJobId()) - .stateStore(stateStore) + .statementStorageService(statementStorageService) .statementId(statementId) .langType(LangType.SQL) .datasourceName(sessionModel.getDatasourceName()) @@ -124,8 +124,8 @@ public StatementId submit(QueryRequest request) { @Override public Optional get(StatementId stID) { - return StateStore.getStatement(stateStore, sessionModel.getDatasourceName()) - .apply(stID.getId()) + return statementStorageService + .getStatement(stID.getId(), sessionModel.getDatasourceName()) .map( model -> Statement.builder() @@ -136,7 +136,7 @@ public Optional get(StatementId stID) { .langType(model.getLangType()) .query(model.getQuery()) .queryId(model.getQueryId()) - .stateStore(stateStore) + .statementStorageService(statementStorageService) .statementModel(model) .build()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index e441492c20..f8d429dd38 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -9,9 +9,11 @@ import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; import java.util.Optional; +import lombok.RequiredArgsConstructor; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.utils.RealTimeProvider; /** @@ -19,25 +21,19 @@ * *

todo. add Session cache and Session sweeper. */ +@RequiredArgsConstructor public class SessionManager { - private final StateStore stateStore; + private final SessionStorageService sessionStorageService; + private final StatementStorageService statementStorageService; private final EMRServerlessClientFactory emrServerlessClientFactory; - private Settings settings; - - public SessionManager( - StateStore stateStore, - EMRServerlessClientFactory emrServerlessClientFactory, - Settings settings) { - this.stateStore = stateStore; - this.emrServerlessClientFactory = emrServerlessClientFactory; - this.settings = settings; - } + private final Settings settings; public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId(request.getDatasourceName())) - .stateStore(stateStore) + .sessionStorageService(sessionStorageService) + .statementStorageService(statementStorageService) .serverlessClient(emrServerlessClientFactory.getClient()) .build(); session.open(request); @@ -64,12 +60,13 @@ public Session createSession(CreateSessionRequest request) { */ public Optional getSession(SessionId sid, String dataSourceName) { Optional model = - StateStore.getSession(stateStore, dataSourceName).apply(sid.getSessionId()); + sessionStorageService.getSession(sid.getSessionId(), dataSourceName); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() .sessionId(sid) - .stateStore(stateStore) + .sessionStorageService(sessionStorageService) + .statementStorageService(statementStorageService) .serverlessClient(emrServerlessClientFactory.getClient()) .sessionModel(model.get()) .sessionInactivityTimeoutMilli( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index 94c1f79511..cab045726c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -6,9 +6,6 @@ package org.opensearch.sql.spark.execution.statement; import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import lombok.Builder; import lombok.Getter; @@ -18,7 +15,7 @@ import org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.rest.model.LangType; /** Statement represent query to execute in session. One statement map to one session. */ @@ -35,7 +32,7 @@ public class Statement { private final String datasourceName; private final String query; private final String queryId; - private final StateStore stateStore; + private final StatementStorageService statementStorageService; @Setter private StatementModel statementModel; @@ -52,7 +49,7 @@ public void open() { datasourceName, query, queryId); - statementModel = createStatement(stateStore, datasourceName).apply(statementModel); + statementModel = statementStorageService.createStatement(statementModel, datasourceName); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; LOG.error(errorMsg); @@ -76,8 +73,8 @@ public void cancel() { } try { this.statementModel = - updateStatementState(stateStore, statementModel.getDatasourceName()) - .apply(this.statementModel, StatementState.CANCELLED); + statementStorageService.updateStatementState( + statementModel, StatementState.CANCELLED, statementModel.getDatasourceName()); } catch (DocumentMissingException e) { String errorMsg = String.format("cancel statement failed. no statement found. statement: %s.", statementId); @@ -85,8 +82,8 @@ public void cancel() { throw new IllegalStateException(errorMsg); } catch (VersionConflictEngineException e) { this.statementModel = - getStatement(stateStore, statementModel.getDatasourceName()) - .apply(statementModel.getId()) + statementStorageService + .getStatement(statementModel.getId(), statementModel.getDatasourceName()) .orElse(this.statementModel); String errorMsg = String.format( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java new file mode 100644 index 0000000000..cfff219eaa --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; + +@RequiredArgsConstructor +public class OpenSearchSessionStorageService implements SessionStorageService { + + private final StateStore stateStore; + + @Override + public SessionModel createSession(SessionModel sessionModel, String datasourceName) { + return stateStore.create( + sessionModel, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public Optional getSession(String id, String datasourceName) { + return stateStore.get( + id, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public SessionModel updateSessionState( + SessionModel sessionModel, SessionState sessionState, String datasourceName) { + return stateStore.updateState( + sessionModel, + sessionState, + SessionModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java new file mode 100644 index 0000000000..b218490d6a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; + +@RequiredArgsConstructor +public class OpenSearchStatementStorageService implements StatementStorageService { + + private final StateStore stateStore; + + @Override + public StatementModel createStatement(StatementModel statementModel, String datasourceName) { + return stateStore.create( + statementModel, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public Optional getStatement(String id, String datasourceName) { + return stateStore.get( + id, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public StatementModel updateStatementState( + StatementModel oldStatementModel, StatementState statementState, String datasourceName) { + return stateStore.updateState( + oldStatementModel, + statementState, + StatementModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java new file mode 100644 index 0000000000..43472b567c --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.util.Optional; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; + +/** Interface for accessing {@link SessionModel} data storage. */ +public interface SessionStorageService { + + SessionModel createSession(SessionModel sessionModel, String datasourceName); + + Optional getSession(String id, String datasourceName); + + SessionModel updateSessionState( + SessionModel sessionModel, SessionState sessionState, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index e50a2837d9..3de83b2f3e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -49,7 +49,6 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.session.SessionType; @@ -250,55 +249,6 @@ private String loadConfigFromResource(String fileName) throws IOException { return IOUtils.toString(fileStream, StandardCharsets.UTF_8); } - /** Helper Functions */ - public static Function createStatement( - StateStore stateStore, String datasourceName) { - return (st) -> - stateStore.create( - st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function> getStatement( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, StatementModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static BiFunction updateStatementState( - StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - StatementModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function createSession( - StateStore stateStore, String datasourceName) { - return (session) -> - stateStore.create( - session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function> getSession( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static BiFunction updateSessionState( - StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - SessionModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - public static Function createJobMetaData( StateStore stateStore, String datasourceName) { return (jobMetadata) -> @@ -341,37 +291,6 @@ public static Supplier activeSessionsCount(StateStore stateStore, String d DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function> getFlintIndexState( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, - FlintIndexStateModel::fromXContent, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function createFlintIndexState( - StateStore stateStore, String datasourceName) { - return (st) -> - stateStore.create( - st, FlintIndexStateModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - /** - * @param stateStore index state store - * @param datasourceName data source name - * @return function that accepts index state doc ID and perform the deletion - */ - public static Function deleteFlintIndexState( - StateStore stateStore, String datasourceName) { - return (docId) -> stateStore.delete(docId, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function createIndexDMLResult( - StateStore stateStore, String indexName) { - return (result) -> stateStore.create(result, IndexDMLResult::copy, indexName); - } - public static Supplier activeRefreshJobCount(StateStore stateStore, String datasourceName) { return () -> stateStore.count( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java new file mode 100644 index 0000000000..0f550eba7c --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.util.Optional; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; + +/** + * Interface for accessing {@link StatementModel} data storage. {@link StatementModel} is an + * abstraction over the query request within a Session. + */ +public interface StatementStorageService { + + StatementModel createStatement(StatementModel statementModel, String datasourceName); + + StatementModel updateStatementState( + StatementModel oldStatementModel, StatementState statementState, String datasourceName); + + Optional getStatement(String id, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index dfc8e4042a..6a33e6d5b6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -28,7 +28,11 @@ import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; @@ -119,10 +123,22 @@ public IndexDMLResultStorageService indexDMLResultStorageService( @Provides public SessionManager sessionManager( - StateStore stateStore, + SessionStorageService sessionStorageService, + StatementStorageService statementStorageService, EMRServerlessClientFactory emrServerlessClientFactory, Settings settings) { - return new SessionManager(stateStore, emrServerlessClientFactory, settings); + return new SessionManager( + sessionStorageService, statementStorageService, emrServerlessClientFactory, settings); + } + + @Provides + public SessionStorageService sessionStorageService(StateStore stateStore) { + return new OpenSearchSessionStorageService(stateStore); + } + + @Provides + public StatementStorageService statementStorageService(StateStore stateStore) { + return new OpenSearchStatementStorageService(stateStore); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index f2d3bb1aa8..4dce252513 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -13,8 +13,6 @@ import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE; import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import com.google.common.collect.ImmutableMap; import java.util.HashMap; @@ -144,7 +142,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); assertNotNull(response.getSessionId()); Optional statementModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(response.getQueryId()); + statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); assertTrue(statementModel.isPresent()); assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); @@ -199,13 +197,13 @@ public void reuseSessionWhenCreateAsyncQuery() { .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); Optional firstModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(first.getQueryId()); + statementStorageService.getStatement(first.getQueryId(), MYS3_DATASOURCE); assertTrue(firstModel.isPresent()); assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); assertEquals(first.getQueryId(), firstModel.get().getQueryId()); Optional secondModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(second.getQueryId()); + statementStorageService.getStatement(second.getQueryId(), MYS3_DATASOURCE); assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); assertEquals(second.getQueryId(), secondModel.get().getQueryId()); @@ -295,7 +293,7 @@ public void withSessionCreateAsyncQueryFailed() { new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null)); assertNotNull(response.getSessionId()); Optional statementModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(response.getQueryId()); + statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); assertTrue(statementModel.isPresent()); assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); @@ -319,7 +317,7 @@ public void withSessionCreateAsyncQueryFailed() { .seqNo(submitted.getSeqNo()) .primaryTerm(submitted.getPrimaryTerm()) .build(); - updateStatementState(stateStore, MYS3_DATASOURCE).apply(mocked, StatementState.FAILED); + statementStorageService.updateStatementState(mocked, StatementState.FAILED, MYS3_DATASOURCE); AsyncQueryExecutionResponse asyncQueryResults = asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 84a2128821..a8ae5fcb1a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -8,9 +8,7 @@ import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.DATASOURCE_URI_HOSTS_DENY_LIST; import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_REFRESH_JOB_LIMIT_SETTING; import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_LIMIT_SETTING; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; +import static org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil.getIndexName; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -63,7 +61,11 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; @@ -85,10 +87,12 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected org.opensearch.sql.common.setting.Settings pluginSettings; protected NodeClient client; protected DataSourceServiceImpl dataSourceService; - protected StateStore stateStore; protected ClusterSettings clusterSettings; protected FlintIndexMetadataService flintIndexMetadataService; protected FlintIndexStateModelService flintIndexStateModelService; + protected StateStore stateStore; + protected SessionStorageService sessionStorageService; + protected StatementStorageService statementStorageService; @Override protected Collection> nodePlugins() { @@ -159,6 +163,8 @@ public void setup() { createIndexWithMappings(otherDm.getResultIndex(), loadResultIndexMappings()); flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); flintIndexStateModelService = new OpenSearchFlintIndexStateModelService(stateStore); + sessionStorageService = new OpenSearchSessionStorageService(stateStore); + statementStorageService = new OpenSearchStatementStorageService(stateStore); } protected FlintIndexOpFactory getFlintIndexOpFactory( @@ -222,7 +228,11 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new QueryHandlerFactory( jobExecutionResponseReader, new FlintIndexMetadataServiceImpl(client), - new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), + new SessionManager( + sessionStorageService, + statementStorageService, + emrServerlessClientFactory, + pluginSettings), new DefaultLeaseManager(pluginSettings, stateStore), new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore), new FlintIndexOpFactory( @@ -234,7 +244,11 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService, - new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), + new SessionManager( + sessionStorageService, + statementStorageService, + emrServerlessClientFactory, + pluginSettings), queryHandlerFactory); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, @@ -341,7 +355,7 @@ public void setConcurrentRefreshJob(long limit) { int search(QueryBuilder query) { SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(DATASOURCE_TO_REQUEST_INDEX.apply(MYS3_DATASOURCE)); + searchRequest.indices(getIndexName(MYS3_DATASOURCE)); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(query); searchRequest.source(searchSourceBuilder); @@ -351,9 +365,9 @@ int search(QueryBuilder query) { } void setSessionState(String sessionId, SessionState sessionState) { - Optional model = getSession(stateStore, MYS3_DATASOURCE).apply(sessionId); + Optional model = sessionStorageService.getSession(sessionId, MYS3_DATASOURCE); SessionModel updated = - updateSessionState(stateStore, MYS3_DATASOURCE).apply(model.get(), sessionState); + sessionStorageService.updateSessionState(model.get(), sessionState, MYS3_DATASOURCE); assertEquals(sessionState, updated.getSessionState()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 6dcc2c17af..bcce6e27c2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -8,7 +8,6 @@ import static org.opensearch.action.support.WriteRequest.RefreshPolicy.WAIT_UNTIL; import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.datasource.model.DataSourceMetadata.DEFAULT_RESULT_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; import com.amazonaws.services.emrserverless.model.JobRunState; import com.google.common.collect.ImmutableList; @@ -30,7 +29,6 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexType; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; @@ -511,8 +509,8 @@ void emrJobWriteResultDoc(Map resultDoc) { /** Simulate EMR-S updates query_execution_request with state */ void emrJobUpdateStatementState(StatementState newState) { - StatementModel stmt = getStatement(stateStore, MYS3_DATASOURCE).apply(queryId).get(); - StateStore.updateStatementState(stateStore, MYS3_DATASOURCE).apply(stmt, newState); + StatementModel stmt = statementStorageService.getStatement(queryId, MYS3_DATASOURCE).get(); + statementStorageService.updateStatementState(stmt, newState, MYS3_DATASOURCE); } void emrJobUpdateJobState(JobRunState jobState) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 8fca190cd6..8aac451f82 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -5,14 +5,12 @@ package org.opensearch.sql.spark.execution.session; -import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_DATASOURCE_NAME; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; +import static org.opensearch.sql.spark.execution.session.SessionTestUtil.createSessionRequest; -import com.amazonaws.services.emrserverless.model.CancelJobRunResult; -import com.amazonaws.services.emrserverless.model.GetJobRunResult; import java.util.HashMap; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -21,30 +19,43 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; -import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.test.OpenSearchIntegTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. */ public class InteractiveSessionTest extends OpenSearchIntegTestCase { - private static final String DS_NAME = "mys3"; - private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); - public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER"; + private static final String indexName = + OpenSearchStateStoreUtil.getIndexName(TEST_DATASOURCE_NAME); private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; - private StateStore stateStore; + private SessionStorageService sessionStorageService; + private StatementStorageService statementStorageService; + private SessionManager sessionManager; @Before public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(client(), clusterService()); + StateStore stateStore = new StateStore(client(), clusterService()); + sessionStorageService = new OpenSearchSessionStorageService(stateStore); + statementStorageService = new OpenSearchStatementStorageService(stateStore); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + sessionManager = + new SessionManager( + sessionStorageService, + statementStorageService, + emrServerlessClientFactory, + sessionSetting()); } @After @@ -56,17 +67,17 @@ public void clean() { @Test public void openCloseSession() { - SessionId sessionId = SessionId.newSessionId(DS_NAME); + SessionId sessionId = SessionId.newSessionId(TEST_DATASOURCE_NAME); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .statementStorageService(statementStorageService) + .sessionStorageService(sessionStorageService) .serverlessClient(emrsClient) .build(); - // open session - TestSession testSession = testSession(session, stateStore); - testSession + SessionAssertions assertions = new SessionAssertions(session); + assertions .open(createSessionRequest()) .assertSessionState(NOT_STARTED) .assertAppId("appId") @@ -76,17 +87,18 @@ public void openCloseSession() { TEST_CLUSTER_NAME + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId.getSessionId()); // close session - testSession.close(); + assertions.close(); emrsClient.cancelJobRunCalled(1); } @Test public void openSessionFailedConflict() { - SessionId sessionId = SessionId.newSessionId(DS_NAME); + SessionId sessionId = SessionId.newSessionId(TEST_DATASOURCE_NAME); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .sessionStorageService(sessionStorageService) + .statementStorageService(statementStorageService) .serverlessClient(emrsClient) .build(); session.open(createSessionRequest()); @@ -94,7 +106,8 @@ public void openSessionFailedConflict() { InteractiveSession duplicateSession = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .sessionStorageService(sessionStorageService) + .statementStorageService(statementStorageService) .serverlessClient(emrsClient) .build(); IllegalStateException exception = @@ -105,11 +118,12 @@ public void openSessionFailedConflict() { @Test public void closeNotExistSession() { - SessionId sessionId = SessionId.newSessionId(DS_NAME); + SessionId sessionId = SessionId.newSessionId(TEST_DATASOURCE_NAME); InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .sessionStorageService(sessionStorageService) + .statementStorageService(statementStorageService) .serverlessClient(emrsClient) .build(); session.open(createSessionRequest()); @@ -123,20 +137,16 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); - TestSession testSession = testSession(session, stateStore); - testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + new SessionAssertions(session) + .assertSessionState(NOT_STARTED) + .assertAppId("appId") + .assertJobId("jobId"); } @Test public void sessionManagerGetSession() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - SessionManager sessionManager = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); Session session = sessionManager.createSession(createSessionRequest()); Optional managerSession = sessionManager.getSession(session.getSessionId()); @@ -146,103 +156,44 @@ public void sessionManagerGetSession() { @Test public void sessionManagerGetSessionNotExist() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - SessionManager sessionManager = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); - Optional managerSession = sessionManager.getSession(SessionId.newSessionId("no-exist")); assertTrue(managerSession.isEmpty()); } @RequiredArgsConstructor - static class TestSession { + class SessionAssertions { private final Session session; - private final StateStore stateStore; - - public static TestSession testSession(Session session, StateStore stateStore) { - return new TestSession(session, stateStore); - } - public TestSession assertSessionState(SessionState expected) { + public SessionAssertions assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId()); + sessionStorageService.getSession(session.getSessionModel().getId(), TEST_DATASOURCE_NAME); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); return this; } - public TestSession assertAppId(String expected) { + public SessionAssertions assertAppId(String expected) { assertEquals(expected, session.getSessionModel().getApplicationId()); return this; } - public TestSession assertJobId(String expected) { + public SessionAssertions assertJobId(String expected) { assertEquals(expected, session.getSessionModel().getJobId()); return this; } - public TestSession open(CreateSessionRequest req) { + public SessionAssertions open(CreateSessionRequest req) { session.open(req); return this; } - public TestSession close() { + public SessionAssertions close() { session.close(); return this; } } - - public static CreateSessionRequest createSessionRequest() { - return new CreateSessionRequest( - TEST_CLUSTER_NAME, - "appId", - "arn", - SparkSubmitParameters.Builder.builder(), - new HashMap<>(), - "resultIndex", - DS_NAME); - } - - public static class TestEMRServerlessClient implements EMRServerlessClient { - - private int startJobRunCalled = 0; - private int cancelJobRunCalled = 0; - - private StartJobRequest startJobRequest; - - @Override - public String startJobRun(StartJobRequest startJobRequest) { - this.startJobRequest = startJobRequest; - startJobRunCalled++; - return "jobId"; - } - - @Override - public GetJobRunResult getJobRunResult(String applicationId, String jobId) { - return null; - } - - @Override - public CancelJobRunResult cancelJobRun( - String applicationId, String jobId, boolean allowExceptionPropagation) { - cancelJobRunCalled++; - return null; - } - - public void startJobRunCalled(int expectedTimes) { - assertEquals(expectedTimes, startJobRunCalled); - } - - public void cancelJobRunCalled(int expectedTimes) { - assertEquals(expectedTimes, cancelJobRunCalled); - } - - public void assertJobNameOfLastRequest(String expectedJobName) { - assertEquals(expectedJobName, startJobRequest.getJobName()); - } - } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index d021bc7248..360018c5b0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -15,18 +15,25 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; @ExtendWith(MockitoExtension.class) public class SessionManagerTest { - @Mock private StateStore stateStore; - + @Mock private SessionStorageService sessionStorageService; + @Mock private StatementStorageService statementStorageService; @Mock private EMRServerlessClientFactory emrServerlessClientFactory; @Test public void sessionEnable() { - Assertions.assertTrue( - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()).isEnabled()); + SessionManager sessionManager = + new SessionManager( + sessionStorageService, + statementStorageService, + emrServerlessClientFactory, + sessionSetting()); + + Assertions.assertTrue(sessionManager.isEnabled()); } public static org.opensearch.sql.common.setting.Settings sessionSetting() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java new file mode 100644 index 0000000000..54451effed --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_DATASOURCE_NAME; + +import java.util.HashMap; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; + +public class SessionTestUtil { + + public static CreateSessionRequest createSessionRequest() { + return new CreateSessionRequest( + TEST_CLUSTER_NAME, + "appId", + "arn", + SparkSubmitParameters.Builder.builder(), + new HashMap<>(), + "resultIndex", + TEST_DATASOURCE_NAME); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/TestEMRServerlessClient.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/TestEMRServerlessClient.java new file mode 100644 index 0000000000..a6b0e6038e --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/TestEMRServerlessClient.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import org.junit.Assert; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; + +public class TestEMRServerlessClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + + private StartJobRequest startJobRequest; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + this.startJobRequest = startJobRequest; + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + return null; + } + + @Override + public CancelJobRunResult cancelJobRun( + String applicationId, String jobId, boolean allowExceptionPropagation) { + cancelJobRunCalled++; + return null; + } + + public void startJobRunCalled(int expectedTimes) { + Assert.assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + Assert.assertEquals(expectedTimes, cancelJobRunCalled); + } + + public void assertJobNameOfLastRequest(String expectedJobName) { + Assert.assertEquals(expectedJobName, startJobRequest.getJobName()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 3a69fa01d7..5f05eed9b9 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -5,16 +5,14 @@ package org.opensearch.sql.spark.execution.statement; -import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_DATASOURCE_NAME; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; +import static org.opensearch.sql.spark.execution.session.SessionTestUtil.createSessionRequest; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; import static org.opensearch.sql.spark.execution.statement.StatementState.RUNNING; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -25,27 +23,41 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; -import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.session.TestEMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.test.OpenSearchIntegTestCase; public class StatementTest extends OpenSearchIntegTestCase { + private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(TEST_DATASOURCE_NAME); - private static final String DS_NAME = "mys3"; - private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); + private StatementStorageService statementStorageService; + private SessionStorageService sessionStorageService; + private TestEMRServerlessClient emrsClient = new TestEMRServerlessClient(); - private StateStore stateStore; - private InteractiveSessionTest.TestEMRServerlessClient emrsClient = - new InteractiveSessionTest.TestEMRServerlessClient(); + private SessionManager sessionManager; @Before public void setup() { - stateStore = new StateStore(client(), clusterService()); + StateStore stateStore = new StateStore(client(), clusterService()); + statementStorageService = new OpenSearchStatementStorageService(stateStore); + sessionStorageService = new OpenSearchSessionStorageService(stateStore); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + + sessionManager = + new SessionManager( + sessionStorageService, + statementStorageService, + emrServerlessClientFactory, + sessionSetting()); } @After @@ -57,21 +69,10 @@ public void clean() { @Test public void openThenCancelStatement() { - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement st = buildStatement(); // submit statement - TestStatement testStatement = testStatement(st, stateStore); + TestStatement testStatement = testStatement(st, statementStorageService); testStatement .open() .assertSessionState(WAITING) @@ -81,35 +82,31 @@ public void openThenCancelStatement() { testStatement.cancel().assertSessionState(CANCELLED); } + private Statement buildStatement() { + return buildStatement(new StatementId("statementId")); + } + + private Statement buildStatement(StatementId stId) { + return Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .datasourceName(TEST_DATASOURCE_NAME) + .query("query") + .queryId("statementId") + .statementStorageService(statementStorageService) + .build(); + } + @Test public void openFailedBecauseConflict() { - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement st = buildStatement(); st.open(); // open statement with same statement id - Statement dupSt = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement dupSt = buildStatement(); IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); assertEquals("statement already exist. statementId=statementId", exception.getMessage()); } @@ -117,18 +114,7 @@ public void openFailedBecauseConflict() { @Test public void cancelNotExistStatement() { StatementId stId = new StatementId("statementId"); - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement st = buildStatement(stId); st.open(); client().delete(new DeleteRequest(indexName, stId.getId())).actionGet(); @@ -142,22 +128,12 @@ public void cancelNotExistStatement() { @Test public void cancelFailedBecauseOfConflict() { StatementId stId = new StatementId("statementId"); - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement st = buildStatement(stId); st.open(); StatementModel running = - updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED); + statementStorageService.updateStatementState( + st.getStatementModel(), CANCELLED, TEST_DATASOURCE_NAME); assertEquals(StatementState.CANCELLED, running.getStatementState()); @@ -231,21 +207,10 @@ public void cancelCancelledStatementFailed() { @Test public void cancelRunningStatementSuccess() { - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement st = buildStatement(); // submit statement - TestStatement testStatement = testStatement(st, stateStore); + TestStatement testStatement = testStatement(st, statementStorageService); testStatement .open() .assertSessionState(WAITING) @@ -259,13 +224,11 @@ public void cancelRunningStatementSuccess() { @Test public void submitStatementInRunningSession() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.RUNNING, TEST_DATASOURCE_NAME); StatementId statementId = session.submit(queryRequest()); assertFalse(statementId.getId().isEmpty()); @@ -273,10 +236,7 @@ public void submitStatementInRunningSession() { @Test public void submitStatementInNotStartedState() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); assertFalse(statementId.getId().isEmpty()); @@ -284,12 +244,10 @@ public void submitStatementInNotStartedState() { @Test public void failToSubmitStatementInDeadState() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.DEAD, TEST_DATASOURCE_NAME); IllegalStateException exception = assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); @@ -301,12 +259,10 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.FAIL, TEST_DATASOURCE_NAME); IllegalStateException exception = assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); @@ -318,10 +274,7 @@ public void failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); Optional statement = session.get(statementId); @@ -338,9 +291,7 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); // other's delete session client() @@ -354,12 +305,10 @@ public void failToSubmitStatementInDeletedSession() { @Test public void getStatementSuccess() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.RUNNING, TEST_DATASOURCE_NAME); StatementId statementId = session.submit(queryRequest()); Optional statement = session.get(statementId); @@ -370,12 +319,10 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) - .createSession(createSessionRequest()); + Session session = sessionManager.createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.RUNNING, TEST_DATASOURCE_NAME); Optional statement = session.get(StatementId.newStatementId("not-exist-id")); assertFalse(statement.isPresent()); @@ -384,17 +331,18 @@ public void getStatementNotExist() { @RequiredArgsConstructor static class TestStatement { private final Statement st; - private final StateStore stateStore; + private final StatementStorageService statementStorageService; - public static TestStatement testStatement(Statement st, StateStore stateStore) { - return new TestStatement(st, stateStore); + public static TestStatement testStatement( + Statement st, StatementStorageService statementStorageService) { + return new TestStatement(st, statementStorageService); } public TestStatement assertSessionState(StatementState expected) { assertEquals(expected, st.getStatementModel().getStatementState()); Optional model = - getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + statementStorageService.getStatement(st.getStatementId().getId(), TEST_DATASOURCE_NAME); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementState()); @@ -405,7 +353,7 @@ public TestStatement assertStatementId(StatementId expected) { assertEquals(expected, st.getStatementModel().getStatementId()); Optional model = - getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + statementStorageService.getStatement(st.getStatementId().getId(), TEST_DATASOURCE_NAME); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementId()); return this; @@ -423,29 +371,20 @@ public TestStatement cancel() { public TestStatement run() { StatementModel model = - updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), RUNNING); + statementStorageService.updateStatementState( + st.getStatementModel(), RUNNING, TEST_DATASOURCE_NAME); st.setStatementModel(model); return this; } } private QueryRequest queryRequest() { - return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1"); + return new QueryRequest( + AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME), LangType.SQL, "select 1"); } private Statement createStatement(StatementId stId) { - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); + Statement st = buildStatement(stId); st.open(); return st; }