diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index e4818d737c..14107712f1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -63,7 +63,7 @@ public CreateAsyncQueryResponse createAsyncQuery( .indexName(dispatchQueryResponse.getIndexName()) .build()); return new CreateAsyncQueryResponse( - dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId()); + dispatchQueryResponse.getQueryId(), dispatchQueryResponse.getSessionId()); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index 2ac67b96ba..16a243edcb 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -31,11 +31,10 @@ public class OpensearchAsyncQueryJobMetadataStorageService @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { - AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); stateStore.create( asyncQueryJobMetadata, AsyncQueryJobMetadata::copy, - OpenSearchStateStoreUtil.getIndexName(queryId.getDataSourceName())); + OpenSearchStateStoreUtil.getIndexName(asyncQueryJobMetadata.getDatasourceName())); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index 08770c7588..be8b543f98 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -21,7 +21,7 @@ @SuperBuilder @EqualsAndHashCode(callSuper = false) public class AsyncQueryJobMetadata extends StateModel { - private final AsyncQueryId queryId; + private final String queryId; private final String applicationId; private final String jobId; private final String resultIndex; @@ -59,6 +59,6 @@ public static AsyncQueryJobMetadata copy( @Override public String getId() { - return queryId.docId(); + return "qid" + queryId; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 8d3803045b..3bdbd8ca74 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -63,7 +63,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { emrServerlessClient.cancelJobRun( asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId(), false); - return asyncQueryJobMetadata.getQueryId().getId(); + return asyncQueryJobMetadata.getQueryId(); } @Override @@ -93,7 +93,12 @@ public DispatchQueryResponse submit( dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); MetricUtils.incrementNumericalMetric(MetricName.EMR_BATCH_QUERY_JOBS_CREATION_COUNT); - return new DispatchQueryResponse( - context.getQueryId(), jobId, dataSourceMetadata.getResultIndex(), null); + return DispatchQueryResponse.builder() + .queryId(context.getQueryId()) + .jobId(jobId) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java new file mode 100644 index 0000000000..8f21a2a04a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Generates QueryId by embedding Datasource name and random UUID */ +public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider { + + @Override + public String getQueryId(DispatchQueryRequest dispatchQueryRequest) { + return AsyncQueryId.newAsyncQueryId(dispatchQueryRequest.getDatasource()).getId(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index 72980dcb1f..199f24977c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -16,13 +16,13 @@ import org.apache.logging.log4j.Logger; import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; @@ -65,39 +65,51 @@ public DispatchQueryResponse submit( getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata); - AsyncQueryId asyncQueryId = + String asyncQueryId = storeIndexDMLResult( + context.getQueryId(), dispatchQueryRequest, dataSourceMetadata, JobRunState.SUCCESS.toString(), StringUtils.EMPTY, getElapsedTimeSince(startTime)); - return new DispatchQueryResponse( - asyncQueryId, DML_QUERY_JOB_ID, dataSourceMetadata.getResultIndex(), null); + return DispatchQueryResponse.builder() + .queryId(asyncQueryId) + .jobId(DML_QUERY_JOB_ID) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } catch (Exception e) { LOG.error(e.getMessage()); - AsyncQueryId asyncQueryId = + String asyncQueryId = storeIndexDMLResult( + context.getQueryId(), dispatchQueryRequest, dataSourceMetadata, JobRunState.FAILED.toString(), e.getMessage(), getElapsedTimeSince(startTime)); - return new DispatchQueryResponse( - asyncQueryId, DML_QUERY_JOB_ID, dataSourceMetadata.getResultIndex(), null); + return DispatchQueryResponse.builder() + .queryId(asyncQueryId) + .jobId(DML_QUERY_JOB_ID) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } } - private AsyncQueryId storeIndexDMLResult( + private String storeIndexDMLResult( + String queryId, DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata, String status, String error, long queryRunTime) { - AsyncQueryId asyncQueryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); IndexDMLResult indexDMLResult = IndexDMLResult.builder() - .queryId(asyncQueryId.getId()) + .queryId(queryId) .status(status) .error(error) .datasourceName(dispatchQueryRequest.getDatasource()) @@ -105,7 +117,7 @@ private AsyncQueryId storeIndexDMLResult( .updateTime(System.currentTimeMillis()) .build(); indexDMLResultStorageService.createIndexDMLResult(indexDMLResult); - return asyncQueryId; + return queryId; } private long getElapsedTimeSince(long startTime) { @@ -143,7 +155,7 @@ private FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexDetails) @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); return jobExecutionResponseReader.getResultWithQueryId( queryId, asyncQueryJobMetadata.getResultIndex()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 552ddeb76e..e41f4a49fd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -49,7 +49,7 @@ public class InteractiveQueryHandler extends AsyncQueryHandler { @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); return jobExecutionResponseReader.getResultWithQueryId( queryId, asyncQueryJobMetadata.getResultIndex()); } @@ -57,7 +57,7 @@ protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQuery @Override protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { JSONObject result = new JSONObject(); - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); StatementState statementState = statement.getStatementState(); result.put(STATUS_FIELD, statementState.getState()); @@ -67,7 +67,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob @Override public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel(); return queryId; } @@ -118,11 +118,14 @@ public DispatchQueryResponse submit( context.getQueryId(), dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); - return new DispatchQueryResponse( - context.getQueryId(), - session.getSessionModel().getJobId(), - dataSourceMetadata.getResultIndex(), - session.getSessionId().getSessionId()); + return DispatchQueryResponse.builder() + .queryId(context.getQueryId()) + .jobId(session.getSessionModel().getJobId()) + .resultIndex(dataSourceMetadata.getResultIndex()) + .sessionId(session.getSessionId().getSessionId()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } private Statement getStatementByQueryId(String sid, String qid) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java new file mode 100644 index 0000000000..2167eb6b7a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Interface for extension point to specify queryId. Called when new query is executed. */ +public interface QueryIdProvider { + String getQueryId(DispatchQueryRequest dispatchQueryRequest); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index edb0a3f507..69c21321a6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -53,7 +53,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName); jobCancelOp.apply(indexMetadata); - return asyncQueryJobMetadata.getQueryId().getId(); + return asyncQueryJobMetadata.getQueryId(); } @Override @@ -61,13 +61,14 @@ public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); - return new DispatchQueryResponse( - resp.getQueryId(), - resp.getJobId(), - resp.getResultIndex(), - resp.getSessionId(), - dataSourceMetadata.getName(), - JobType.BATCH, - context.getIndexQueryDetails().openSearchIndexName()); + return DispatchQueryResponse.builder() + .queryId(resp.getQueryId()) + .jobId(resp.getJobId()) + .resultIndex(resp.getResultIndex()) + .sessionId(resp.getSessionId()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.BATCH) + .indexName(context.getIndexQueryDetails().openSearchIndexName()) + .build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index b6f5bcceb3..67d2767493 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -12,7 +12,6 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -36,6 +35,7 @@ public class SparkQueryDispatcher { private final DataSourceService dataSourceService; private final SessionManager sessionManager; private final QueryHandlerFactory queryHandlerFactory; + private final QueryIdProvider queryIdProvider; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { DataSourceMetadata dataSourceMetadata = @@ -59,12 +59,12 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) } } - private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( + private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { return DispatchQueryContext.builder() .dataSourceMetadata(dataSourceMetadata) .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); + .queryId(queryIdProvider.getQueryId(dispatchQueryRequest)); } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 886e7d176a..0649e81418 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,7 +12,6 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.utils.MetricUtils; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; @@ -82,13 +81,13 @@ public DispatchQueryResponse submit( dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); MetricUtils.incrementNumericalMetric(MetricName.EMR_STREAMING_QUERY_JOBS_CREATION_COUNT); - return new DispatchQueryResponse( - AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), - jobId, - dataSourceMetadata.getResultIndex(), - null, - dataSourceMetadata.getName(), - JobType.STREAMING, - indexQueryDetails.openSearchIndexName()); + return DispatchQueryResponse.builder() + .queryId(context.getQueryId()) + .jobId(jobId) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.STREAMING) + .indexName(indexQueryDetails.openSearchIndexName()) + .build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java index d3400d86bf..7b694e47f0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java @@ -9,12 +9,11 @@ import lombok.Builder; import lombok.Getter; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Getter @Builder public class DispatchQueryContext { - private final AsyncQueryId queryId; + private final String queryId; private final DataSourceMetadata dataSourceMetadata; private final Map tags; private final IndexQueryDetails indexQueryDetails; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 2c39aab1d4..b97d9fd7b0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -1,37 +1,16 @@ package org.opensearch.sql.spark.dispatcher.model; +import lombok.Builder; import lombok.Getter; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Getter +@Builder public class DispatchQueryResponse { - private final AsyncQueryId queryId; + private final String queryId; private final String jobId; private final String resultIndex; private final String sessionId; private final String datasourceName; private final JobType jobType; private final String indexName; - - public DispatchQueryResponse( - AsyncQueryId queryId, String jobId, String resultIndex, String sessionId) { - this(queryId, jobId, resultIndex, sessionId, null, JobType.INTERACTIVE, null); - } - - public DispatchQueryResponse( - AsyncQueryId queryId, - String jobId, - String resultIndex, - String sessionId, - String datasourceName, - JobType jobType, - String indexName) { - this.queryId = queryId; - this.jobId = jobId; - this.resultIndex = resultIndex; - this.sessionId = sessionId; - this.datasourceName = datasourceName; - this.jobType = jobType; - this.indexName = indexName; - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 8758bcb4a3..9920fb9aec 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -94,7 +94,7 @@ public StatementId submit(QueryRequest request) { } else { sessionModel = model.get(); if (!END_STATE.contains(sessionModel.getSessionState())) { - String qid = request.getQueryId().getId(); + String qid = request.getQueryId(); StatementId statementId = newStatementId(qid); Statement st = Statement.builder() diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java index c365265224..db2e96b3cd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java @@ -6,12 +6,11 @@ package org.opensearch.sql.spark.execution.statement; import lombok.Data; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.rest.model.LangType; @Data public class QueryRequest { - private final AsyncQueryId queryId; + private final String queryId; private final LangType langType; private final String query; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java index a4209a0ce7..39a1ec83e4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java @@ -20,7 +20,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.JobType; @@ -37,7 +36,7 @@ public XContentBuilder toXContent(AsyncQueryJobMetadata jobMetadata, ToXContent. throws IOException { return XContentFactory.jsonBuilder() .startObject() - .field(QUERY_ID, jobMetadata.getQueryId().getId()) + .field(QUERY_ID, jobMetadata.getQueryId()) .field(TYPE, TYPE_JOBMETA) .field(JOB_ID, jobMetadata.getJobId()) .field(APPLICATION_ID, jobMetadata.getApplicationId()) @@ -59,7 +58,7 @@ public AsyncQueryJobMetadata fromXContent(XContentParser parser, long seqNo, lon parser.nextToken(); switch (fieldName) { case QUERY_ID: - builder.queryId(new AsyncQueryId(parser.textOrNull())); + builder.queryId(parser.textOrNull()); break; case JOB_ID: builder.jobId(parser.textOrNull()); diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 615a914fee..85af39b52d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -25,7 +25,9 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.DatasourceEmbeddedQueryIdProvider; import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; +import org.opensearch.sql.spark.dispatcher.QueryIdProvider; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService; @@ -82,8 +84,15 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) { public SparkQueryDispatcher sparkQueryDispatcher( DataSourceService dataSourceService, SessionManager sessionManager, - QueryHandlerFactory queryHandlerFactory) { - return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); + QueryHandlerFactory queryHandlerFactory, + QueryIdProvider queryIdProvider) { + return new SparkQueryDispatcher( + dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + } + + @Provides + public QueryIdProvider queryIdProvider() { + return new DatasourceEmbeddedQueryIdProvider(); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 2b84f967f0..43dd4880e7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -11,7 +11,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -31,7 +30,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; @@ -41,6 +39,7 @@ import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; @@ -55,7 +54,7 @@ public class AsyncQueryExecutorServiceImplTest { @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; @Mock private RequestContext requestContext; - private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); + private final String QUERY_ID = "QUERY_ID"; @BeforeEach void setUp() { @@ -89,7 +88,12 @@ void testCreateAsyncQuery() { TEST_CLUSTER_NAME, sparkSubmitParameterModifier); when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest)) - .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null)); + .thenReturn( + DispatchQueryResponse.builder() + .queryId(QUERY_ID) + .jobId(EMR_JOB_ID) + .jobType(JobType.INTERACTIVE) + .build()); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest, requestContext); @@ -99,7 +103,7 @@ void testCreateAsyncQuery() { verify(sparkExecutionEngineConfigSupplier, times(1)) .getSparkExecutionEngineConfig(requestContext); verify(sparkQueryDispatcher, times(1)).dispatch(expectedDispatchQueryRequest); - Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId()); + Assertions.assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId()); } @Test @@ -115,7 +119,12 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { modifier, TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null)); + .thenReturn( + DispatchQueryResponse.builder() + .queryId(QUERY_ID) + .jobId(EMR_JOB_ID) + .jobType(JobType.INTERACTIVE) + .build()); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index b15a911364..4991095aca 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -58,6 +58,7 @@ import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.DatasourceEmbeddedQueryIdProvider; import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; @@ -262,7 +263,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( statementStorageService, emrServerlessClientFactory, pluginSettings), - queryHandlerFactory); + queryHandlerFactory, + new DatasourceEmbeddedQueryIdProvider()); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index 431f5b2b15..cd7a11149d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -40,15 +40,16 @@ public void setup() { public void testStoreJobMetadata() { AsyncQueryJobMetadata expected = AsyncQueryJobMetadata.builder() - .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME)) + .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME).getId()) .jobId(EMR_JOB_ID) .applicationId(EMRS_APPLICATION_ID) .resultIndex(MOCK_RESULT_INDEX) + .datasourceName(DS_NAME) .build(); opensearchJobMetadataStorageService.storeJobMetadata(expected); Optional actual = - opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId()); assertTrue(actual.isPresent()); assertEquals(expected, actual.get()); @@ -60,16 +61,17 @@ public void testStoreJobMetadata() { public void testStoreJobMetadataWithResultExtraData() { AsyncQueryJobMetadata expected = AsyncQueryJobMetadata.builder() - .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME)) + .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME).getId()) .jobId(EMR_JOB_ID) .applicationId(EMRS_APPLICATION_ID) .resultIndex(MOCK_RESULT_INDEX) .sessionId(MOCK_SESSION_ID) + .datasourceName(DS_NAME) .build(); opensearchJobMetadataStorageService.storeJobMetadata(expected); Optional actual = - opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId()); assertTrue(actual.isPresent()); assertEquals(expected, actual.get()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 7d43ccc7e3..2e536ef6b3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; @@ -43,12 +44,23 @@ @ExtendWith(MockitoExtension.class) class IndexDMLHandlerTest { + private static final String QUERY_ID = "QUERY_ID"; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private FlintIndexMetadataService flintIndexMetadataService; @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @InjectMocks IndexDMLHandler indexDMLHandler; + + private static final DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setName("mys3") + .setDescription("test description") + .setConnector(DataSourceType.S3GLUE) + .setDataSourceStatus(ACTIVE) + .build(); + @Test public void getResponseFromExecutor() { JSONObject result = new IndexDMLHandler(null, null, null, null).getResponseFromExecutor(null); @@ -59,28 +71,7 @@ public void getResponseFromExecutor() { @Test public void testWhenIndexDetailsAreNotFound() { - IndexDMLHandler indexDMLHandler = - new IndexDMLHandler( - jobExecutionResponseReader, - flintIndexMetadataService, - indexDMLResultStorageService, - flintIndexOpFactory); - DispatchQueryRequest dispatchQueryRequest = - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - "DROP INDEX", - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier); - DataSourceMetadata metadata = - new DataSourceMetadata.Builder() - .setName("mys3") - .setDescription("test description") - .setConnector(DataSourceType.S3GLUE) - .setDataSourceStatus(ACTIVE) - .build(); + DispatchQueryRequest dispatchQueryRequest = getDispatchQueryRequest("DROP INDEX"); IndexQueryDetails indexQueryDetails = IndexQueryDetails.builder() .mvName("mys3.default.http_logs_metrics") @@ -88,6 +79,7 @@ public void testWhenIndexDetailsAreNotFound() { .build(); DispatchQueryContext dispatchQueryContext = DispatchQueryContext.builder() + .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) .build(); @@ -103,28 +95,7 @@ public void testWhenIndexDetailsAreNotFound() { @Test public void testWhenIndexDetailsWithInvalidQueryActionType() { FlintIndexMetadata flintIndexMetadata = mock(FlintIndexMetadata.class); - IndexDMLHandler indexDMLHandler = - new IndexDMLHandler( - jobExecutionResponseReader, - flintIndexMetadataService, - indexDMLResultStorageService, - flintIndexOpFactory); - DispatchQueryRequest dispatchQueryRequest = - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - "CREATE INDEX", - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier); - DataSourceMetadata metadata = - new DataSourceMetadata.Builder() - .setName("mys3") - .setDescription("test description") - .setConnector(DataSourceType.S3GLUE) - .setDataSourceStatus(ACTIVE) - .build(); + DispatchQueryRequest dispatchQueryRequest = getDispatchQueryRequest("CREATE INDEX"); IndexQueryDetails indexQueryDetails = IndexQueryDetails.builder() .mvName("mys3.default.http_logs_metrics") @@ -133,6 +104,7 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { .build(); DispatchQueryContext dispatchQueryContext = DispatchQueryContext.builder() + .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) .build(); @@ -144,6 +116,17 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext); } + private DispatchQueryRequest getDispatchQueryRequest(String query) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier); + } + @Test public void testStaticMethods() { Assertions.assertTrue(IndexDMLHandler.isIndexDMLQuery("dropIndexJobId")); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index a22ce7f460..5d04c86cce 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -18,7 +18,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -57,7 +56,6 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; @@ -92,6 +90,7 @@ public class SparkQueryDispatcherTest { @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock private QueryIdProvider queryIdProvider; @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -101,7 +100,7 @@ public class SparkQueryDispatcherTest { private SparkQueryDispatcher sparkQueryDispatcher; - private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); + private final String QUERY_ID = "QUERY_ID"; @Captor ArgumentCaptor startJobRequestArgumentCaptor; @@ -117,8 +116,8 @@ void setUp() { flintIndexOpFactory, emrServerlessClientFactory); sparkQueryDispatcher = - new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); - new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); + new SparkQueryDispatcher( + dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); } @Test @@ -834,7 +833,7 @@ void testCancelJob() { String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); - Assertions.assertEquals(QUERY_ID.getId(), queryId); + Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -897,7 +896,7 @@ void testCancelQueryWithNoSessionId() { String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); - Assertions.assertEquals(QUERY_ID.getId(), queryId); + Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -1224,7 +1223,7 @@ private AsyncQueryJobMetadata asyncQueryJobMetadata() { private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( String statementId, String sessionId) { return AsyncQueryJobMetadata.builder() - .queryId(new AsyncQueryId(statementId)) + .queryId(statementId) .applicationId(EMRS_APPLICATION_ID) .jobId(EMR_JOB_ID) .sessionId(sessionId) diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index e3f610000c..357a09c3ee 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -371,7 +371,7 @@ public TestStatement run() { private QueryRequest queryRequest() { return new QueryRequest( - AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME), LangType.SQL, "select 1"); + AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME).getId(), LangType.SQL, "select 1"); } private Statement createStatement(StatementId stId) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java index cf658ea017..f0cce5405c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java @@ -16,7 +16,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.JobType; @@ -29,7 +28,7 @@ class AsyncQueryJobMetadataXContentSerializerTest { void toXContentShouldSerializeAsyncQueryJobMetadata() throws Exception { AsyncQueryJobMetadata jobMetadata = AsyncQueryJobMetadata.builder() - .queryId(new AsyncQueryId("query1")) + .queryId("query1") .applicationId("app1") .jobId("job1") .resultIndex("result1") @@ -72,7 +71,7 @@ void fromXContentShouldDeserializeAsyncQueryJobMetadata() throws Exception { AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L); - assertEquals("query1", jobMetadata.getQueryId().getId()); + assertEquals("query1", jobMetadata.getQueryId()); assertEquals("job1", jobMetadata.getJobId()); assertEquals("app1", jobMetadata.getApplicationId()); assertEquals("result1", jobMetadata.getResultIndex()); @@ -142,7 +141,7 @@ void fromXContentShouldDeserializeAsyncQueryWithJobTypeNUll() throws Exception { AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L); - assertEquals("query1", jobMetadata.getQueryId().getId()); + assertEquals("query1", jobMetadata.getQueryId()); assertEquals("job1", jobMetadata.getJobId()); assertEquals("app1", jobMetadata.getApplicationId()); assertEquals("result1", jobMetadata.getResultIndex());