diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index d38c8554ae..b0c339e93d 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -39,5 +39,5 @@ CreateAsyncQueryResponse createAsyncQuery( * @param queryId queryId. * @return {@link String} cancelledQueryId. */ - String cancelQuery(String queryId); + String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 6d3d5b6765..d304766465 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -106,11 +106,11 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { } @Override - public String cancelQuery(String queryId) { + public String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext) { Optional asyncQueryJobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (asyncQueryJobMetadata.isPresent()) { - return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get()); + return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext); } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java index d61ac17aa3..2bafd88b85 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -12,6 +12,7 @@ import com.amazonaws.services.emrserverless.model.JobRunState; import org.json.JSONObject; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -54,7 +55,9 @@ protected abstract JSONObject getResponseFromResultIndex( protected abstract JSONObject getResponseFromExecutor( AsyncQueryJobMetadata asyncQueryJobMetadata); - public abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); + public abstract String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext); public abstract DispatchQueryResponse submit( DispatchQueryRequest request, DispatchQueryContext context); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 2654f83aad..661ebe27fc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -16,6 +16,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -61,7 +62,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { emrServerlessClient.cancelJobRun( asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId(), false); return asyncQueryJobMetadata.getQueryId(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index e8413f469c..f8217142c3 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -62,9 +62,11 @@ public DispatchQueryResponse submit( long startTime = System.currentTimeMillis(); try { IndexQueryDetails indexDetails = context.getIndexQueryDetails(); - FlintIndexMetadata indexMetadata = getFlintIndexMetadata(indexDetails); + FlintIndexMetadata indexMetadata = + getFlintIndexMetadata(indexDetails, context.getAsyncQueryRequestContext()); - getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata); + getIndexOp(dispatchQueryRequest, indexDetails) + .apply(indexMetadata, context.getAsyncQueryRequestContext()); String asyncQueryId = storeIndexDMLResult( @@ -146,9 +148,11 @@ private FlintIndexOp getIndexOp( } } - private FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexDetails) { + private FlintIndexMetadata getFlintIndexMetadata( + IndexQueryDetails indexDetails, AsyncQueryRequestContext asyncQueryRequestContext) { Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexDetails.openSearchIndexName(), asyncQueryRequestContext); if (!indexMetadataMap.containsKey(indexDetails.openSearchIndexName())) { throw new IllegalStateException( String.format( @@ -174,7 +178,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { throw new IllegalArgumentException("can't cancel index DML query"); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index ec43bccf11..9a9baedde2 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -16,6 +16,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -71,7 +72,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String queryId = asyncQueryJobMetadata.getQueryId(); getStatementByQueryId( asyncQueryJobMetadata.getSessionId(), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 99984ecc46..38145a143e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -8,6 +8,7 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -51,10 +52,13 @@ public RefreshQueryHandler( } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String datasourceName = asyncQueryJobMetadata.getDatasourceName(); Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(asyncQueryJobMetadata.getIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + asyncQueryJobMetadata.getIndexName(), asyncQueryRequestContext); if (!indexMetadataMap.containsKey(asyncQueryJobMetadata.getIndexName())) { throw new IllegalStateException( String.format( @@ -62,7 +66,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { } FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName); - jobCancelOp.apply(indexMetadata); + jobCancelOp.apply(indexMetadata, asyncQueryRequestContext); return asyncQueryJobMetadata.getQueryId(); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index a424db4c34..a6fdd3f102 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -162,9 +162,11 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) .getQueryResponse(asyncQueryJobMetadata); } - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) - .cancelJob(asyncQueryJobMetadata); + .cancelJob(asyncQueryJobMetadata, asyncQueryRequestContext); } private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 2fbf2466da..80d4be27cf 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,6 +12,7 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -46,7 +47,9 @@ public StreamingQueryHandler( } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { throw new IllegalArgumentException( "can't cancel index DML query, using ALTER auto_refresh=off statement to stop job, using" + " VACUUM statement to stop job and delete data"); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java index ad274e429e..ece14c2a7b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.flint; import java.util.Map; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; /** Interface for FlintIndexMetadataReader */ @@ -15,16 +16,22 @@ public interface FlintIndexMetadataService { * Retrieves a map of {@link FlintIndexMetadata} instances matching the specified index pattern. * * @param indexPattern indexPattern. + * @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService * @return A map of {@link FlintIndexMetadata} instances against indexName, each providing * metadata access for a matched index. Returns an empty list if no indices match the pattern. */ - Map getFlintIndexMetadata(String indexPattern); + Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext); /** * Performs validation and updates flint index to manual refresh. * * @param indexName indexName. * @param flintIndexOptions flintIndexOptions. + * @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService */ - void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions); + void updateIndexToManualRefresh( + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java index 94647f4e07..3872f2d5a0 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java @@ -6,20 +6,58 @@ package org.opensearch.sql.spark.flint; import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; /** * Abstraction over flint index state storage. Flint index state will maintain the status of each * flint index. */ public interface FlintIndexStateModelService { - FlintIndexStateModel createFlintIndexStateModel(FlintIndexStateModel flintIndexStateModel); - Optional getFlintIndexStateModel(String id, String datasourceName); + /** + * Create Flint index state record + * + * @param flintIndexStateModel the model to be saved + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return saved model + */ + FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, AsyncQueryRequestContext asyncQueryRequestContext); + /** + * Get Flint index state record + * + * @param id ID(latestId) of the Flint index state record + * @param datasourceName datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return retrieved model + */ + Optional getFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); + + /** + * Update Flint index state record + * + * @param flintIndexStateModel the model to be updated + * @param flintIndexState new state + * @param datasourceName Datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return Updated model + */ FlintIndexStateModel updateFlintIndexState( FlintIndexStateModel flintIndexStateModel, FlintIndexState flintIndexState, - String datasourceName); + String datasourceName, + AsyncQueryRequestContext asyncQueryRequestContext); - boolean deleteFlintIndexStateModel(String id, String datasourceName); + /** + * Delete Flint index state record + * + * @param id ID(latestId) of the Flint index state record + * @param datasourceName datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return true if deleted, otherwise false + */ + boolean deleteFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 244f4aee11..78d217b8dc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -16,6 +16,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.NotNull; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -33,30 +34,33 @@ public abstract class FlintIndexOp { private final EMRServerlessClientFactory emrServerlessClientFactory; /** Apply operation on {@link FlintIndexMetadata} */ - public void apply(FlintIndexMetadata metadata) { + public void apply( + FlintIndexMetadata metadata, AsyncQueryRequestContext asyncQueryRequestContext) { // todo, remove this logic after IndexState feature is enabled in Flint. Optional latestId = metadata.getLatestId(); if (latestId.isEmpty()) { - takeActionWithoutOCC(metadata); + takeActionWithoutOCC(metadata, asyncQueryRequestContext); } else { - FlintIndexStateModel initialFlintIndexStateModel = getFlintIndexStateModel(latestId.get()); + FlintIndexStateModel initialFlintIndexStateModel = + getFlintIndexStateModel(latestId.get(), asyncQueryRequestContext); // 1.validate state. validFlintIndexInitialState(initialFlintIndexStateModel); // 2.begin, move to transitioning state FlintIndexStateModel transitionedFlintIndexStateModel = - moveToTransitioningState(initialFlintIndexStateModel); + moveToTransitioningState(initialFlintIndexStateModel, asyncQueryRequestContext); // 3.runOp try { - runOp(metadata, transitionedFlintIndexStateModel); - commit(transitionedFlintIndexStateModel); + runOp(metadata, transitionedFlintIndexStateModel, asyncQueryRequestContext); + commit(transitionedFlintIndexStateModel, asyncQueryRequestContext); } catch (Throwable e) { LOG.error("Rolling back transient log due to transaction operation failure", e); try { flintIndexStateModelService.updateFlintIndexState( transitionedFlintIndexStateModel, initialFlintIndexStateModel.getIndexState(), - datasourceName); + datasourceName, + asyncQueryRequestContext); } catch (Exception ex) { LOG.error("Failed to rollback transient log", ex); } @@ -66,9 +70,11 @@ public void apply(FlintIndexMetadata metadata) { } @NotNull - private FlintIndexStateModel getFlintIndexStateModel(String latestId) { + private FlintIndexStateModel getFlintIndexStateModel( + String latestId, AsyncQueryRequestContext asyncQueryRequestContext) { Optional flintIndexOptional = - flintIndexStateModelService.getFlintIndexStateModel(latestId, datasourceName); + flintIndexStateModelService.getFlintIndexStateModel( + latestId, datasourceName, asyncQueryRequestContext); if (flintIndexOptional.isEmpty()) { String errorMsg = String.format(Locale.ROOT, "no state found. docId: %s", latestId); LOG.error(errorMsg); @@ -77,7 +83,8 @@ private FlintIndexStateModel getFlintIndexStateModel(String latestId) { return flintIndexOptional.get(); } - private void takeActionWithoutOCC(FlintIndexMetadata metadata) { + private void takeActionWithoutOCC( + FlintIndexMetadata metadata, AsyncQueryRequestContext asyncQueryRequestContext) { // take action without occ. FlintIndexStateModel fakeModel = FlintIndexStateModel.builder() @@ -89,7 +96,7 @@ private void takeActionWithoutOCC(FlintIndexMetadata metadata) { .lastUpdateTime(System.currentTimeMillis()) .error("") .build(); - runOp(metadata, fakeModel); + runOp(metadata, fakeModel, asyncQueryRequestContext); } private void validFlintIndexInitialState(FlintIndexStateModel flintIndex) { @@ -103,13 +110,14 @@ private void validFlintIndexInitialState(FlintIndexStateModel flintIndex) { } } - private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flintIndex) { + private FlintIndexStateModel moveToTransitioningState( + FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug("Moving to transitioning state before committing."); FlintIndexState transitioningState = transitioningState(); try { flintIndex = flintIndexStateModelService.updateFlintIndexState( - flintIndex, transitioningState(), datasourceName); + flintIndex, transitioningState(), datasourceName, asyncQueryRequestContext); } catch (Exception e) { String errorMsg = String.format(Locale.ROOT, "Moving to transition state:%s failed.", transitioningState); @@ -119,16 +127,18 @@ private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flint return flintIndex; } - private void commit(FlintIndexStateModel flintIndex) { + private void commit( + FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug("Committing the transaction and moving to stable state."); FlintIndexState stableState = stableState(); try { if (stableState == FlintIndexState.NONE) { LOG.info("Deleting index state with docId: " + flintIndex.getLatestId()); flintIndexStateModelService.deleteFlintIndexStateModel( - flintIndex.getLatestId(), datasourceName); + flintIndex.getLatestId(), datasourceName, asyncQueryRequestContext); } else { - flintIndexStateModelService.updateFlintIndexState(flintIndex, stableState, datasourceName); + flintIndexStateModelService.updateFlintIndexState( + flintIndex, stableState, datasourceName, asyncQueryRequestContext); } } catch (Exception e) { String errorMsg = @@ -192,7 +202,10 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel) /** get transitioningState */ abstract FlintIndexState transitioningState(); - abstract void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex); + abstract void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext); /** get stableState */ abstract FlintIndexState stableState(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 9955320253..4a00195ebf 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -48,11 +49,14 @@ FlintIndexState transitioningState() { @SneakyThrows @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) { + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug( "Running alter index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); this.flintIndexMetadataService.updateIndexToManualRefresh( - flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions); + flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions, asyncQueryRequestContext); cancelStreamingJob(flintIndexStateModel); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java index 02c8e39c66..504a8f93c9 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -38,7 +39,10 @@ FlintIndexState transitioningState() { /** cancel EMR-S job, wait cancelled state upto 15s. */ @SneakyThrows @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) { + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java index 6613c29870..fc9b644fc7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; @@ -40,7 +41,10 @@ FlintIndexState transitioningState() { /** cancel EMR-S job, wait cancelled state upto 15s. */ @SneakyThrows @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) { + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java index a0ef955adf..06aaf8ef9f 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java @@ -7,6 +7,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -42,7 +43,10 @@ FlintIndexState transitioningState() { } @Override - public void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex) { + public void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext) { LOG.info("Vacuuming Flint index {}", flintIndexMetadata.getOpensearchIndexName()); flintIndexClient.deleteIndex(flintIndexMetadata.getOpensearchIndexName()); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index d82d3bdab7..ff92762a7c 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -249,7 +249,8 @@ public void createAlterIndexQuery() { assertNull(response.getSessionId()); verifyGetQueryIdCalled(); verify(flintIndexMetadataService) - .updateIndexToManualRefresh(eq(indexName), flintIndexOptionsArgumentCaptor.capture()); + .updateIndexToManualRefresh( + eq(indexName), flintIndexOptionsArgumentCaptor.capture(), eq(asyncQueryRequestContext)); FlintIndexOptions flintIndexOptions = flintIndexOptionsArgumentCaptor.getValue(); assertFalse(flintIndexOptions.autoRefresh()); verifyCancelJobRunCalled(); @@ -430,7 +431,7 @@ public void cancelInteractiveQuery() { when(statementStorageService.updateStatementState(statementModel, StatementState.CANCELLED)) .thenReturn(canceledStatementModel); - String result = asyncQueryExecutorService.cancelQuery(QUERY_ID); + String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); verify(statementStorageService).updateStatementState(statementModel, StatementState.CANCELLED); @@ -441,14 +442,15 @@ public void cancelIndexDMLQuery() { givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobId(DROP_INDEX_JOB_ID)); assertThrows( - IllegalArgumentException.class, () -> asyncQueryExecutorService.cancelQuery(QUERY_ID)); + IllegalArgumentException.class, + () -> asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext)); } @Test public void cancelRefreshQuery() { givenJobMetadataExists( getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME)); - when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME)) + when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( INDEX_NAME, @@ -463,7 +465,7 @@ public void cancelRefreshQuery() { new GetJobRunResult() .withJobRun(new JobRun().withJobRunId(JOB_ID).withState("Cancelled"))); - String result = asyncQueryExecutorService.cancelQuery(QUERY_ID); + String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); verifyCancelJobRunCalled(); @@ -475,7 +477,8 @@ public void cancelStreamingQuery() { givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.STREAMING)); assertThrows( - IllegalArgumentException.class, () -> asyncQueryExecutorService.cancelQuery(QUERY_ID)); + IllegalArgumentException.class, + () -> asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext)); } @Test @@ -483,7 +486,7 @@ public void cancelBatchQuery() { givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobId(JOB_ID)); givenCancelJobRunSucceed(); - String result = asyncQueryExecutorService.cancelQuery(QUERY_ID); + String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); verifyCancelJobRunCalled(); @@ -500,7 +503,7 @@ private void givenSparkExecutionEngineConfigIsSupplied() { } private void givenFlintIndexMetadataExists(String indexName) { - when(flintIndexMetadataService.getFlintIndexMetadata(indexName)) + when(flintIndexMetadataService.getFlintIndexMetadata(indexName, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( indexName, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index dbc51bb0ad..5d8d9a3b63 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -206,7 +206,8 @@ void testCancelJobWithJobNotFound() { AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( - AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID)); + AsyncQueryNotFoundException.class, + () -> jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext)); Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); @@ -218,9 +219,10 @@ void testCancelJobWithJobNotFound() { void testCancelJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.of(getAsyncQueryJobMetadata())); - when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata())).thenReturn(EMR_JOB_ID); + when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata(), asyncQueryRequestContext)) + .thenReturn(EMR_JOB_ID); - String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); + String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext); Assertions.assertEquals(EMR_JOB_ID, jobId); verifyNoInteractions(sparkExecutionEngineConfigSupplier); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 877d6ec32b..9a3c4e663e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.sql.datasource.model.DataSourceStatus.ACTIVE; @@ -27,6 +28,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -50,6 +52,7 @@ class IndexDMLHandlerTest { @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @InjectMocks IndexDMLHandler indexDMLHandler; @@ -82,8 +85,10 @@ public void testWhenIndexDetailsAreNotFound() { .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) + .asyncQueryRequestContext(asyncQueryRequestContext) .build(); - Mockito.when(flintIndexMetadataService.getFlintIndexMetadata(any())) + Mockito.when( + flintIndexMetadataService.getFlintIndexMetadata(any(), eq(asyncQueryRequestContext))) .thenReturn(new HashMap<>()); DispatchQueryResponse dispatchQueryResponse = @@ -107,10 +112,12 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) + .asyncQueryRequestContext(asyncQueryRequestContext) .build(); HashMap flintMetadataMap = new HashMap<>(); flintMetadataMap.put(indexQueryDetails.openSearchIndexName(), flintIndexMetadata); - when(flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName())) + when(flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext)) .thenReturn(flintMetadataMap); indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index a7a79c758e..592309cb75 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -871,7 +871,8 @@ void testCancelJob() { .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); Assertions.assertEquals(QUERY_ID, queryId); } @@ -884,7 +885,8 @@ void testCancelQueryWithSession() { String queryId = sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID), + asyncQueryRequestContext); verifyNoInteractions(emrServerlessClient); verify(statement, times(1)).cancel(); @@ -900,7 +902,8 @@ void testCancelQueryWithInvalidSession() { IllegalArgumentException.class, () -> sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"))); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"), + asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); verifyNoInteractions(session); @@ -916,7 +919,8 @@ void testCancelQueryWithInvalidStatementId() { IllegalArgumentException.class, () -> sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID))); + asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID), + asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); verifyNoInteractions(statement); @@ -933,7 +937,8 @@ void testCancelQueryWithNoSessionId() { .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); Assertions.assertEquals(QUERY_ID, queryId); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java index 0c82733ae6..8105629822 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.xcontent.XContentSerializerUtil; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -28,21 +29,26 @@ public class FlintIndexOpTest { @Mock private FlintIndexStateModelService flintIndexStateModelService; @Mock private EMRServerlessClientFactory mockEmrServerlessClientFactory; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Test public void testApplyWithTransitioningStateFailure() { FlintIndexMetadata metadata = mock(FlintIndexMetadata.class); when(metadata.getLatestId()).thenReturn(Optional.of("latestId")); FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata); - when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) + when(flintIndexStateModelService.getFlintIndexStateModel( + eq("latestId"), any(), eq(asyncQueryRequestContext))) .thenReturn(Optional.of(fakeModel)); - when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState( + any(), any(), any(), eq(asyncQueryRequestContext))) .thenThrow(new RuntimeException("Transitioning state failed")); FlintIndexOp flintIndexOp = new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = - Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); + Assertions.assertThrows( + IllegalStateException.class, + () -> flintIndexOp.apply(metadata, asyncQueryRequestContext)); Assertions.assertEquals( "Moving to transition state:DELETING failed.", illegalStateException.getMessage()); @@ -53,9 +59,11 @@ public void testApplyWithCommitFailure() { FlintIndexMetadata metadata = mock(FlintIndexMetadata.class); when(metadata.getLatestId()).thenReturn(Optional.of("latestId")); FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata); - when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) + when(flintIndexStateModelService.getFlintIndexStateModel( + eq("latestId"), any(), eq(asyncQueryRequestContext))) .thenReturn(Optional.of(fakeModel)); - when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState( + any(), any(), any(), eq(asyncQueryRequestContext))) .thenReturn( FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2))) .thenThrow(new RuntimeException("Commit state failed")) @@ -65,7 +73,9 @@ public void testApplyWithCommitFailure() { new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = - Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); + Assertions.assertThrows( + IllegalStateException.class, + () -> flintIndexOp.apply(metadata, asyncQueryRequestContext)); Assertions.assertEquals( "commit failed. target stable state: [DELETED]", illegalStateException.getMessage()); @@ -76,9 +86,11 @@ public void testApplyWithRollBackFailure() { FlintIndexMetadata metadata = mock(FlintIndexMetadata.class); when(metadata.getLatestId()).thenReturn(Optional.of("latestId")); FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata); - when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any())) + when(flintIndexStateModelService.getFlintIndexStateModel( + eq("latestId"), any(), eq(asyncQueryRequestContext))) .thenReturn(Optional.of(fakeModel)); - when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState( + any(), any(), any(), eq(asyncQueryRequestContext))) .thenReturn( FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2))) .thenThrow(new RuntimeException("Commit state failed")) @@ -87,7 +99,9 @@ public void testApplyWithRollBackFailure() { new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory); IllegalStateException illegalStateException = - Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); + Assertions.assertThrows( + IllegalStateException.class, + () -> flintIndexOp.apply(metadata, asyncQueryRequestContext)); Assertions.assertEquals( "commit failed. target stable state: [DELETED]", illegalStateException.getMessage()); @@ -125,7 +139,10 @@ FlintIndexState transitioningState() { } @Override - void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex) {} + void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext) {} @Override FlintIndexState stableState() { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java index 60fa13dc93..26858c18fe 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -38,6 +39,7 @@ class FlintIndexOpVacuumTest { @Mock EMRServerlessClientFactory emrServerlessClientFactory; @Mock FlintIndexStateModel flintIndexStateModel; @Mock FlintIndexStateModel transitionedFlintIndexStateModel; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; RuntimeException testException = new RuntimeException("Test Exception"); @@ -55,110 +57,154 @@ public void setUp() { @Test public void testApplyWithEmptyLatestId() { - flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID); + flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID, asyncQueryRequestContext); verify(flintIndexClient).deleteIndex(INDEX_NAME); } @Test public void testApplyWithFlintIndexStateNotFound() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.empty()); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithNotDeletedState() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.ACTIVE); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithUpdateFlintIndexStateThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenThrow(testException); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithRunOpThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME); assertThrows( - Exception.class, () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + Exception.class, + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); verify(flintIndexStateModelService) .updateFlintIndexState( - transitionedFlintIndexStateModel, FlintIndexState.DELETED, DATASOURCE_NAME); + transitionedFlintIndexStateModel, + FlintIndexState.DELETED, + DATASOURCE_NAME, + asyncQueryRequestContext); } @Test public void testApplyWithRunOpThrowAndRollbackThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME); when(flintIndexStateModelService.updateFlintIndexState( - transitionedFlintIndexStateModel, FlintIndexState.DELETED, DATASOURCE_NAME)) + transitionedFlintIndexStateModel, + FlintIndexState.DELETED, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenThrow(testException); assertThrows( - Exception.class, () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + Exception.class, + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyWithDeleteFlintIndexStateModelThrow() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); - when(flintIndexStateModelService.deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.deleteFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenThrow(testException); assertThrows( IllegalStateException.class, - () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID)); + () -> + flintIndexOpVacuum.apply( + FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext)); } @Test public void testApplyHappyPath() { - when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME)) + when(flintIndexStateModelService.getFlintIndexStateModel( + LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(flintIndexStateModel)); when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED); when(flintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME)) + flintIndexStateModel, + FlintIndexState.VACUUMING, + DATASOURCE_NAME, + asyncQueryRequestContext)) .thenReturn(transitionedFlintIndexStateModel); when(transitionedFlintIndexStateModel.getLatestId()).thenReturn(LATEST_ID); - flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID); + flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext); - verify(flintIndexStateModelService).deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME); + verify(flintIndexStateModelService) + .deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext); verify(flintIndexClient).deleteIndex(INDEX_NAME); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java b/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java index 31b1ecb49c..2dd0a4a7cf 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java @@ -17,6 +17,7 @@ import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; @@ -29,6 +30,8 @@ public class FlintStreamingJobHouseKeeperTask implements Runnable { private final DataSourceService dataSourceService; private final FlintIndexMetadataService flintIndexMetadataService; private final FlintIndexOpFactory flintIndexOpFactory; + private final NullAsyncQueryRequestContext nullAsyncQueryRequestContext = + new NullAsyncQueryRequestContext(); private static final Logger LOGGER = LogManager.getLogger(FlintStreamingJobHouseKeeperTask.class); protected static final AtomicBoolean isRunning = new AtomicBoolean(false); @@ -91,7 +94,9 @@ private void dropAutoRefreshIndex( String autoRefreshIndex, FlintIndexMetadata flintIndexMetadata, String datasourceName) { // When the datasource is deleted. Possibly Replace with VACUUM Operation. LOGGER.info("Attempting to drop auto refresh index: {}", autoRefreshIndex); - flintIndexOpFactory.getDrop(datasourceName).apply(flintIndexMetadata); + flintIndexOpFactory + .getDrop(datasourceName) + .apply(flintIndexMetadata, nullAsyncQueryRequestContext); LOGGER.info("Successfully dropped index: {}", autoRefreshIndex); } @@ -100,7 +105,9 @@ private void alterAutoRefreshIndex( LOGGER.info("Attempting to alter index: {}", autoRefreshIndex); FlintIndexOptions flintIndexOptions = new FlintIndexOptions(); flintIndexOptions.setOption(FlintIndexOptions.AUTO_REFRESH, "false"); - flintIndexOpFactory.getAlter(flintIndexOptions, datasourceName).apply(flintIndexMetadata); + flintIndexOpFactory + .getAlter(flintIndexOptions, datasourceName) + .apply(flintIndexMetadata, nullAsyncQueryRequestContext); LOGGER.info("Successfully altered index: {}", autoRefreshIndex); } @@ -119,7 +126,7 @@ private String getDataSourceName(FlintIndexMetadata flintIndexMetadata) { private Map getAllAutoRefreshIndices() { Map flintIndexMetadataHashMap = - flintIndexMetadataService.getFlintIndexMetadata("flint_*"); + flintIndexMetadataService.getFlintIndexMetadata("flint_*", nullAsyncQueryRequestContext); return flintIndexMetadataHashMap.entrySet().stream() .filter(entry -> entry.getValue().getFlintIndexOptions().autoRefresh()) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java index 893b33b39d..b8352d15b2 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java @@ -33,6 +33,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.client.Client; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; /** Implementation of {@link FlintIndexMetadataService} */ @@ -49,7 +50,8 @@ public class FlintIndexMetadataServiceImpl implements FlintIndexMetadataService Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH, WATERMARK_DELAY, CHECKPOINT_LOCATION)); @Override - public Map getFlintIndexMetadata(String indexPattern) { + public Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) { GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings().setIndices(indexPattern).get(); Map indexMetadataMap = new HashMap<>(); @@ -73,7 +75,10 @@ public Map getFlintIndexMetadata(String indexPattern } @Override - public void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions) { + public void updateIndexToManualRefresh( + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext) { GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings().setIndices(indexName).get(); Map flintMetadataMap = diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java index 5781c3e44b..eba338e912 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java @@ -7,6 +7,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer; @@ -20,7 +21,8 @@ public class OpenSearchFlintIndexStateModelService implements FlintIndexStateMod public FlintIndexStateModel updateFlintIndexState( FlintIndexStateModel flintIndexStateModel, FlintIndexState flintIndexState, - String datasourceName) { + String datasourceName, + AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.updateState( flintIndexStateModel, flintIndexState, @@ -29,14 +31,16 @@ public FlintIndexStateModel updateFlintIndexState( } @Override - public Optional getFlintIndexStateModel(String id, String datasourceName) { + public Optional getFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.get( id, serializer::fromXContent, OpenSearchStateStoreUtil.getIndexName(datasourceName)); } @Override public FlintIndexStateModel createFlintIndexStateModel( - FlintIndexStateModel flintIndexStateModel) { + FlintIndexStateModel flintIndexStateModel, + AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.create( flintIndexStateModel.getId(), flintIndexStateModel, @@ -45,7 +49,8 @@ public FlintIndexStateModel createFlintIndexStateModel( } @Override - public boolean deleteFlintIndexStateModel(String id, String datasourceName) { + public boolean deleteFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.delete(id, OpenSearchStateStoreUtil.getIndexName(datasourceName)); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 232a280db5..ce80351f70 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -13,6 +13,7 @@ import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -41,7 +42,9 @@ protected void doExecute( CancelAsyncQueryActionRequest request, ActionListener listener) { try { - String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); + String jobId = + asyncQueryExecutorService.cancelQuery( + request.getQueryId(), new NullAsyncQueryRequestContext()); listener.onResponse( new CancelAsyncQueryActionResponse( String.format("Deleted async query with id: %s", jobId))); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 3ff806bf50..ede8a348b4 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -71,7 +71,8 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { emrsClient.getJobRunResultCalled(1); // 3. cancel async query. - String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + String cancelQueryId = + asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext); assertEquals(response.getQueryId(), cancelQueryId); emrsClient.cancelJobRunCalled(1); } @@ -163,7 +164,8 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); // 3. cancel async query. - String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + String cancelQueryId = + asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext); assertEquals(response.getQueryId(), cancelQueryId); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 2eed7b13a0..29c42446b3 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -152,7 +152,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); assertEquals("can't cancel index DML query", exception.getMessage()); }); } @@ -326,7 +328,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); assertEquals("can't cancel index DML query", exception.getMessage()); }); } @@ -901,7 +905,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); assertEquals( "can't cancel index DML query, using ALTER auto_refresh=off statement to stop" + " job, using VACUUM statement to stop job and delete data", @@ -944,7 +950,9 @@ public GetJobRunResult getJobRunResult( flintIndexJob.refreshing(); // 2. Cancel query - String cancelResponse = asyncQueryExecutorService.cancelQuery(response.getQueryId()); + String cancelResponse = + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext); assertNotNull(cancelResponse); assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName)); @@ -992,7 +1000,9 @@ public GetJobRunResult getJobRunResult( IllegalStateException illegalStateException = Assertions.assertThrows( IllegalStateException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery( + response.getQueryId(), asyncQueryRequestContext)); Assertions.assertEquals( "Transaction failed as flint index is not in a valid state.", illegalStateException.getMessage()); @@ -1038,6 +1048,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. Cancel query Assertions.assertThrows( IllegalStateException.class, - () -> asyncQueryExecutorService.cancelQuery(response.getQueryId())); + () -> + asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext)); } } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java index 6c82188ee6..0dc8f02820 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java @@ -18,6 +18,7 @@ public class MockFlintSparkJob { private FlintIndexStateModel stateModel; private FlintIndexStateModelService flintIndexStateModelService; private String datasource; + private AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); public MockFlintSparkJob( FlintIndexStateModelService flintIndexStateModelService, String latestId, String datasource) { @@ -34,12 +35,15 @@ public MockFlintSparkJob( .lastUpdateTime(System.currentTimeMillis()) .error("") .build(); - stateModel = flintIndexStateModelService.createFlintIndexStateModel(stateModel); + stateModel = + flintIndexStateModelService.createFlintIndexStateModel( + stateModel, asyncQueryRequestContext); } public void transition(FlintIndexState newState) { stateModel = - flintIndexStateModelService.updateFlintIndexState(stateModel, newState, datasource); + flintIndexStateModelService.updateFlintIndexState( + stateModel, newState, datasource, asyncQueryRequestContext); } public void refreshing() { @@ -68,7 +72,8 @@ public void deleted() { public void assertState(FlintIndexState expected) { Optional stateModelOpt = - flintIndexStateModelService.getFlintIndexStateModel(stateModel.getId(), datasource); + flintIndexStateModelService.getFlintIndexStateModel( + stateModel.getId(), datasource, asyncQueryRequestContext); assertTrue(stateModelOpt.isPresent()); assertEquals(expected, stateModelOpt.get().getIndexState()); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index c5964a61e3..0a3a180932 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -20,6 +20,7 @@ import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceSpec; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.asyncquery.model.MockFlintIndex; import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; @@ -393,13 +394,16 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataService() { @Override - public Map getFlintIndexMetadata(String indexPattern) { + public Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) { throw new RuntimeException("Couldn't fetch details from ElasticSearch"); } @Override public void updateIndexToManualRefresh( - String indexName, FlintIndexOptions flintIndexOptions) {} + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext) {} }; FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java b/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java index f6baa82dd2..b1321cc132 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java @@ -29,6 +29,7 @@ import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; @@ -39,6 +40,8 @@ public class FlintIndexMetadataServiceImplTest { @Mock(answer = RETURNS_DEEP_STUBS) private Client client; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; + @SneakyThrows @Test void testGetJobIdFromFlintSkippingIndexMetadata() { @@ -56,8 +59,11 @@ void testGetJobIdFromFlintSkippingIndexMetadata() { .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); + Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + Assertions.assertEquals( "00fhelvq7peuao0", indexMetadataMap.get(indexQueryDetails.openSearchIndexName()).getJobId()); @@ -80,8 +86,11 @@ void testGetJobIdFromFlintSkippingIndexMetadataWithIndexState() { .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.SKIPPING) .build(); + Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + FlintIndexMetadata metadata = indexMetadataMap.get(indexQueryDetails.openSearchIndexName()); Assertions.assertEquals("00fhelvq7peuao0", metadata.getJobId()); } @@ -103,8 +112,11 @@ void testGetJobIdFromFlintCoveringIndexMetadata() { .indexType(FlintIndexType.COVERING) .build(); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); + Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + Assertions.assertEquals( "00fdmvv9hp8u0o0q", indexMetadataMap.get(indexQueryDetails.openSearchIndexName()).getJobId()); @@ -126,8 +138,11 @@ void testGetJobIDWithNPEException() { .indexQueryActionType(IndexQueryActionType.DROP) .indexType(FlintIndexType.COVERING) .build(); + Map flintIndexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext); + Assertions.assertFalse( flintIndexMetadataMap.containsKey("flint_mys3_default_http_logs_cv1_index")); } @@ -148,8 +163,10 @@ void testGetJobIDWithNPEExceptionForMultipleIndices() { indexMappingsMap.put(indexName, mappings); mockNodeClientIndicesMappings("flint_mys3*", indexMappingsMap); FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); + Map flintIndexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata("flint_mys3*"); + flintIndexMetadataService.getFlintIndexMetadata("flint_mys3*", asyncQueryRequestContext); + Assertions.assertFalse( flintIndexMetadataMap.containsKey("flint_mys3_default_http_logs_cv1_index")); Assertions.assertTrue( diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java index 977f77b397..4faff41fe6 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java @@ -16,6 +16,7 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer; @@ -30,6 +31,7 @@ public class OpenSearchFlintIndexStateModelServiceTest { @Mock FlintIndexState flintIndexState; @Mock FlintIndexStateModel responseFlintIndexStateModel; @Mock FlintIndexStateModelXContentSerializer flintIndexStateModelXContentSerializer; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; @InjectMocks OpenSearchFlintIndexStateModelService openSearchFlintIndexStateModelService; @@ -40,7 +42,7 @@ void updateFlintIndexState() { FlintIndexStateModel result = openSearchFlintIndexStateModelService.updateFlintIndexState( - flintIndexStateModel, flintIndexState, DATASOURCE); + flintIndexStateModel, flintIndexState, DATASOURCE, asyncQueryRequestContext); assertEquals(responseFlintIndexStateModel, result); } @@ -51,7 +53,8 @@ void getFlintIndexStateModel() { .thenReturn(Optional.of(responseFlintIndexStateModel)); Optional result = - openSearchFlintIndexStateModelService.getFlintIndexStateModel("ID", DATASOURCE); + openSearchFlintIndexStateModelService.getFlintIndexStateModel( + "ID", DATASOURCE, asyncQueryRequestContext); assertEquals(responseFlintIndexStateModel, result.get()); } @@ -63,7 +66,8 @@ void createFlintIndexStateModel() { when(flintIndexStateModel.getDatasourceName()).thenReturn(DATASOURCE); FlintIndexStateModel result = - openSearchFlintIndexStateModelService.createFlintIndexStateModel(flintIndexStateModel); + openSearchFlintIndexStateModelService.createFlintIndexStateModel( + flintIndexStateModel, asyncQueryRequestContext); assertEquals(responseFlintIndexStateModel, result); } @@ -73,7 +77,8 @@ void deleteFlintIndexStateModel() { when(mockStateStore.delete(any(), any())).thenReturn(true); boolean result = - openSearchFlintIndexStateModelService.deleteFlintIndexStateModel(ID, DATASOURCE); + openSearchFlintIndexStateModelService.deleteFlintIndexStateModel( + ID, DATASOURCE, asyncQueryRequestContext); assertTrue(result); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index 2ff76b9b57..a2581fdea2 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -7,6 +7,8 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -24,6 +26,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -36,7 +39,6 @@ public class TransportCancelAsyncQueryRequestActionTest { @Mock private TransportCancelAsyncQueryRequestAction action; @Mock private Task task; @Mock private ActionListener actionListener; - @Mock private AsyncQueryExecutorServiceImpl asyncQueryExecutorService; @Captor @@ -54,8 +56,12 @@ public void setUp() { @Test public void testDoExecute() { CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); - when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); + when(asyncQueryExecutorService.cancelQuery( + eq(EMR_JOB_ID), any(NullAsyncQueryRequestContext.class))) + .thenReturn(EMR_JOB_ID); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse = deleteJobActionResponseArgumentCaptor.getValue(); @@ -66,8 +72,12 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); - doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); + doThrow(new RuntimeException("Error")) + .when(asyncQueryExecutorService) + .cancelQuery(eq(EMR_JOB_ID), any(NullAsyncQueryRequestContext.class)); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); Exception exception = exceptionArgumentCaptor.getValue(); Assertions.assertTrue(exception instanceof RuntimeException);