From 03a5e4dc828593eb111df47a4d3636ddceb507c2 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Tue, 4 Jun 2024 13:39:15 -0700 Subject: [PATCH] Abstract queryId generation (#2695) * Abstract queryId generation Signed-off-by: Tomoyuki Morita * Remove OpenSearch specific id mapping from model classes Signed-off-by: Tomoyuki Morita * Fix code style Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita --- .../AsyncQueryExecutorServiceImpl.java | 2 +- ...hAsyncQueryJobMetadataStorageService.java} | 31 ++++---- .../model/AsyncQueryJobMetadata.java | 10 ++- .../spark/dispatcher/BatchQueryHandler.java | 11 ++- .../DatasourceEmbeddedQueryIdProvider.java | 18 +++++ .../sql/spark/dispatcher/IndexDMLHandler.java | 36 ++++++---- .../dispatcher/InteractiveQueryHandler.java | 19 ++--- .../sql/spark/dispatcher/QueryIdProvider.java | 13 ++++ .../spark/dispatcher/RefreshQueryHandler.java | 19 ++--- .../dispatcher/SparkQueryDispatcher.java | 6 +- .../dispatcher/StreamingQueryHandler.java | 17 +++-- .../model/DispatchQueryContext.java | 3 +- .../model/DispatchQueryResponse.java | 27 +------ .../dispatcher/model/IndexDMLResult.java | 4 +- .../execution/session/InteractiveSession.java | 2 +- .../execution/statement/QueryRequest.java | 3 +- .../OpenSearchSessionStorageService.java | 1 + .../OpenSearchStatementStorageService.java | 1 + .../execution/statestore/StateStore.java | 26 +------ ...yncQueryJobMetadataXContentSerializer.java | 5 +- ...OpenSearchFlintIndexStateModelService.java | 1 + ...penSearchIndexDMLResultStorageService.java | 10 ++- .../config/AsyncExecutorServiceModule.java | 17 +++-- .../AsyncQueryExecutorServiceImplTest.java | 21 ++++-- .../AsyncQueryExecutorServiceSpec.java | 6 +- ...ncQueryJobMetadataStorageServiceTest.java} | 29 ++++---- .../spark/dispatcher/IndexDMLHandlerTest.java | 71 +++++++------------ .../dispatcher/SparkQueryDispatcherTest.java | 15 ++-- .../execution/statement/StatementTest.java | 2 +- ...ueryJobMetadataXContentSerializerTest.java | 7 +- ...SearchFlintIndexStateModelServiceTest.java | 3 +- 31 files changed, 227 insertions(+), 209 deletions(-) rename spark/src/main/java/org/opensearch/sql/spark/asyncquery/{OpensearchAsyncQueryJobMetadataStorageService.java => OpenSearchAsyncQueryJobMetadataStorageService.java} (67%) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java rename spark/src/test/java/org/opensearch/sql/spark/asyncquery/{OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java => OpenSearchAsyncQueryJobMetadataStorageServiceTest.java} (76%) diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index e4818d737c..14107712f1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -63,7 +63,7 @@ public CreateAsyncQueryResponse createAsyncQuery( .indexName(dispatchQueryResponse.getIndexName()) .build()); return new CreateAsyncQueryResponse( - dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId()); + dispatchQueryResponse.getQueryId(), dispatchQueryResponse.getSessionId()); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java similarity index 67% rename from spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java rename to spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java index 2ac67b96ba..5356f14143 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.sql.spark.asyncquery; @@ -12,43 +10,46 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer; +import org.opensearch.sql.spark.utils.IDUtils; -/** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */ +/** OpenSearch implementation of {@link AsyncQueryJobMetadataStorageService} */ @RequiredArgsConstructor -public class OpensearchAsyncQueryJobMetadataStorageService +public class OpenSearchAsyncQueryJobMetadataStorageService implements AsyncQueryJobMetadataStorageService { private final StateStore stateStore; private final AsyncQueryJobMetadataXContentSerializer asyncQueryJobMetadataXContentSerializer; private static final Logger LOGGER = - LogManager.getLogger(OpensearchAsyncQueryJobMetadataStorageService.class); + LogManager.getLogger(OpenSearchAsyncQueryJobMetadataStorageService.class); @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { - AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); stateStore.create( + mapIdToDocumentId(asyncQueryJobMetadata.getId()), asyncQueryJobMetadata, AsyncQueryJobMetadata::copy, - OpenSearchStateStoreUtil.getIndexName(queryId.getDataSourceName())); + OpenSearchStateStoreUtil.getIndexName(asyncQueryJobMetadata.getDatasourceName())); + } + + private String mapIdToDocumentId(String id) { + return "qid" + id; } @Override - public Optional getJobMetadata(String qid) { + public Optional getJobMetadata(String queryId) { try { - AsyncQueryId queryId = new AsyncQueryId(qid); return stateStore.get( - queryId.docId(), + mapIdToDocumentId(queryId), asyncQueryJobMetadataXContentSerializer::fromXContent, - OpenSearchStateStoreUtil.getIndexName(queryId.getDataSourceName())); + OpenSearchStateStoreUtil.getIndexName(IDUtils.decode(queryId))); } catch (Exception e) { LOGGER.error("Error while fetching the job metadata.", e); - throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", qid)); + throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", queryId)); } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index 08770c7588..e1f30edc10 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.sql.spark.asyncquery.model; @@ -21,7 +19,7 @@ @SuperBuilder @EqualsAndHashCode(callSuper = false) public class AsyncQueryJobMetadata extends StateModel { - private final AsyncQueryId queryId; + private final String queryId; private final String applicationId; private final String jobId; private final String resultIndex; @@ -59,6 +57,6 @@ public static AsyncQueryJobMetadata copy( @Override public String getId() { - return queryId.docId(); + return queryId; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 8d3803045b..3bdbd8ca74 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -63,7 +63,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { emrServerlessClient.cancelJobRun( asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId(), false); - return asyncQueryJobMetadata.getQueryId().getId(); + return asyncQueryJobMetadata.getQueryId(); } @Override @@ -93,7 +93,12 @@ public DispatchQueryResponse submit( dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); MetricUtils.incrementNumericalMetric(MetricName.EMR_BATCH_QUERY_JOBS_CREATION_COUNT); - return new DispatchQueryResponse( - context.getQueryId(), jobId, dataSourceMetadata.getResultIndex(), null); + return DispatchQueryResponse.builder() + .queryId(context.getQueryId()) + .jobId(jobId) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java new file mode 100644 index 0000000000..c170040718 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.utils.IDUtils; + +/** Generates QueryId by embedding Datasource name and random UUID */ +public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider { + + @Override + public String getQueryId(DispatchQueryRequest dispatchQueryRequest) { + return IDUtils.encode(dispatchQueryRequest.getDatasource()); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index 72980dcb1f..199f24977c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -16,13 +16,13 @@ import org.apache.logging.log4j.Logger; import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; @@ -65,39 +65,51 @@ public DispatchQueryResponse submit( getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata); - AsyncQueryId asyncQueryId = + String asyncQueryId = storeIndexDMLResult( + context.getQueryId(), dispatchQueryRequest, dataSourceMetadata, JobRunState.SUCCESS.toString(), StringUtils.EMPTY, getElapsedTimeSince(startTime)); - return new DispatchQueryResponse( - asyncQueryId, DML_QUERY_JOB_ID, dataSourceMetadata.getResultIndex(), null); + return DispatchQueryResponse.builder() + .queryId(asyncQueryId) + .jobId(DML_QUERY_JOB_ID) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } catch (Exception e) { LOG.error(e.getMessage()); - AsyncQueryId asyncQueryId = + String asyncQueryId = storeIndexDMLResult( + context.getQueryId(), dispatchQueryRequest, dataSourceMetadata, JobRunState.FAILED.toString(), e.getMessage(), getElapsedTimeSince(startTime)); - return new DispatchQueryResponse( - asyncQueryId, DML_QUERY_JOB_ID, dataSourceMetadata.getResultIndex(), null); + return DispatchQueryResponse.builder() + .queryId(asyncQueryId) + .jobId(DML_QUERY_JOB_ID) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } } - private AsyncQueryId storeIndexDMLResult( + private String storeIndexDMLResult( + String queryId, DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata, String status, String error, long queryRunTime) { - AsyncQueryId asyncQueryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); IndexDMLResult indexDMLResult = IndexDMLResult.builder() - .queryId(asyncQueryId.getId()) + .queryId(queryId) .status(status) .error(error) .datasourceName(dispatchQueryRequest.getDatasource()) @@ -105,7 +117,7 @@ private AsyncQueryId storeIndexDMLResult( .updateTime(System.currentTimeMillis()) .build(); indexDMLResultStorageService.createIndexDMLResult(indexDMLResult); - return asyncQueryId; + return queryId; } private long getElapsedTimeSince(long startTime) { @@ -143,7 +155,7 @@ private FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexDetails) @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); return jobExecutionResponseReader.getResultWithQueryId( queryId, asyncQueryJobMetadata.getResultIndex()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 552ddeb76e..e41f4a49fd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -49,7 +49,7 @@ public class InteractiveQueryHandler extends AsyncQueryHandler { @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); return jobExecutionResponseReader.getResultWithQueryId( queryId, asyncQueryJobMetadata.getResultIndex()); } @@ -57,7 +57,7 @@ protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQuery @Override protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { JSONObject result = new JSONObject(); - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); StatementState statementState = statement.getStatementState(); result.put(STATUS_FIELD, statementState.getState()); @@ -67,7 +67,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob @Override public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); + String queryId = asyncQueryJobMetadata.getQueryId(); getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel(); return queryId; } @@ -118,11 +118,14 @@ public DispatchQueryResponse submit( context.getQueryId(), dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); - return new DispatchQueryResponse( - context.getQueryId(), - session.getSessionModel().getJobId(), - dataSourceMetadata.getResultIndex(), - session.getSessionId().getSessionId()); + return DispatchQueryResponse.builder() + .queryId(context.getQueryId()) + .jobId(session.getSessionModel().getJobId()) + .resultIndex(dataSourceMetadata.getResultIndex()) + .sessionId(session.getSessionId().getSessionId()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.INTERACTIVE) + .build(); } private Statement getStatementByQueryId(String sid, String qid) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java new file mode 100644 index 0000000000..2167eb6b7a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Interface for extension point to specify queryId. Called when new query is executed. */ +public interface QueryIdProvider { + String getQueryId(DispatchQueryRequest dispatchQueryRequest); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index edb0a3f507..69c21321a6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -53,7 +53,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName); jobCancelOp.apply(indexMetadata); - return asyncQueryJobMetadata.getQueryId().getId(); + return asyncQueryJobMetadata.getQueryId(); } @Override @@ -61,13 +61,14 @@ public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); - return new DispatchQueryResponse( - resp.getQueryId(), - resp.getJobId(), - resp.getResultIndex(), - resp.getSessionId(), - dataSourceMetadata.getName(), - JobType.BATCH, - context.getIndexQueryDetails().openSearchIndexName()); + return DispatchQueryResponse.builder() + .queryId(resp.getQueryId()) + .jobId(resp.getJobId()) + .resultIndex(resp.getResultIndex()) + .sessionId(resp.getSessionId()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.BATCH) + .indexName(context.getIndexQueryDetails().openSearchIndexName()) + .build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index b6f5bcceb3..67d2767493 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -12,7 +12,6 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -36,6 +35,7 @@ public class SparkQueryDispatcher { private final DataSourceService dataSourceService; private final SessionManager sessionManager; private final QueryHandlerFactory queryHandlerFactory; + private final QueryIdProvider queryIdProvider; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { DataSourceMetadata dataSourceMetadata = @@ -59,12 +59,12 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) } } - private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( + private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { return DispatchQueryContext.builder() .dataSourceMetadata(dataSourceMetadata) .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); + .queryId(queryIdProvider.getQueryId(dispatchQueryRequest)); } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 886e7d176a..0649e81418 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,7 +12,6 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.utils.MetricUtils; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; @@ -82,13 +81,13 @@ public DispatchQueryResponse submit( dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); MetricUtils.incrementNumericalMetric(MetricName.EMR_STREAMING_QUERY_JOBS_CREATION_COUNT); - return new DispatchQueryResponse( - AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), - jobId, - dataSourceMetadata.getResultIndex(), - null, - dataSourceMetadata.getName(), - JobType.STREAMING, - indexQueryDetails.openSearchIndexName()); + return DispatchQueryResponse.builder() + .queryId(context.getQueryId()) + .jobId(jobId) + .resultIndex(dataSourceMetadata.getResultIndex()) + .datasourceName(dataSourceMetadata.getName()) + .jobType(JobType.STREAMING) + .indexName(indexQueryDetails.openSearchIndexName()) + .build(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java index d3400d86bf..7b694e47f0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java @@ -9,12 +9,11 @@ import lombok.Builder; import lombok.Getter; import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Getter @Builder public class DispatchQueryContext { - private final AsyncQueryId queryId; + private final String queryId; private final DataSourceMetadata dataSourceMetadata; private final Map tags; private final IndexQueryDetails indexQueryDetails; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index 2c39aab1d4..b97d9fd7b0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -1,37 +1,16 @@ package org.opensearch.sql.spark.dispatcher.model; +import lombok.Builder; import lombok.Getter; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Getter +@Builder public class DispatchQueryResponse { - private final AsyncQueryId queryId; + private final String queryId; private final String jobId; private final String resultIndex; private final String sessionId; private final String datasourceName; private final JobType jobType; private final String indexName; - - public DispatchQueryResponse( - AsyncQueryId queryId, String jobId, String resultIndex, String sessionId) { - this(queryId, jobId, resultIndex, sessionId, null, JobType.INTERACTIVE, null); - } - - public DispatchQueryResponse( - AsyncQueryId queryId, - String jobId, - String resultIndex, - String sessionId, - String datasourceName, - JobType jobType, - String indexName) { - this.queryId = queryId; - this.jobId = jobId; - this.resultIndex = resultIndex; - this.sessionId = sessionId; - this.datasourceName = datasourceName; - this.jobType = jobType; - this.indexName = indexName; - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java index 42bddf6c15..a276076f4b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java @@ -16,8 +16,6 @@ @SuperBuilder @EqualsAndHashCode(callSuper = false) public class IndexDMLResult extends StateModel { - public static final String DOC_ID_PREFIX = "index"; - private final String queryId; private final String status; private final String error; @@ -39,6 +37,6 @@ public static IndexDMLResult copy(IndexDMLResult copy, ImmutableMap T create(T st, CopyBuilder builder, String indexName) { + public T create( + String docId, T st, CopyBuilder builder, String indexName) { try { if (!this.clusterService.state().routingTable().hasIndex(indexName)) { createIndex(indexName); @@ -86,7 +86,7 @@ public T create(T st, CopyBuilder builder, String inde XContentSerializer serializer = getXContentSerializer(st); IndexRequest indexRequest = new IndexRequest(indexName) - .id(st.getId()) + .id(docId) .source(serializer.toXContent(st, ToXContent.EMPTY_PARAMS)) .setIfSeqNo(getSeqNo(st)) .setIfPrimaryTerm(getPrimaryTerm(st)) @@ -268,26 +268,6 @@ private String loadConfigFromResource(String fileName) throws IOException { return IOUtils.toString(fileStream, StandardCharsets.UTF_8); } - public static Function createJobMetaData( - StateStore stateStore, String datasourceName) { - return (jobMetadata) -> - stateStore.create( - jobMetadata, - AsyncQueryJobMetadata::copy, - OpenSearchStateStoreUtil.getIndexName(datasourceName)); - } - - public static Function> getJobMetaData( - StateStore stateStore, String datasourceName) { - AsyncQueryJobMetadataXContentSerializer asyncQueryJobMetadataXContentSerializer = - new AsyncQueryJobMetadataXContentSerializer(); - return (docId) -> - stateStore.get( - docId, - asyncQueryJobMetadataXContentSerializer::fromXContent, - OpenSearchStateStoreUtil.getIndexName(datasourceName)); - } - public static Supplier activeSessionsCount(StateStore stateStore, String datasourceName) { return () -> stateStore.count( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java index a4209a0ce7..39a1ec83e4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java @@ -20,7 +20,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.JobType; @@ -37,7 +36,7 @@ public XContentBuilder toXContent(AsyncQueryJobMetadata jobMetadata, ToXContent. throws IOException { return XContentFactory.jsonBuilder() .startObject() - .field(QUERY_ID, jobMetadata.getQueryId().getId()) + .field(QUERY_ID, jobMetadata.getQueryId()) .field(TYPE, TYPE_JOBMETA) .field(JOB_ID, jobMetadata.getJobId()) .field(APPLICATION_ID, jobMetadata.getApplicationId()) @@ -59,7 +58,7 @@ public AsyncQueryJobMetadata fromXContent(XContentParser parser, long seqNo, lon parser.nextToken(); switch (fieldName) { case QUERY_ID: - builder.queryId(new AsyncQueryId(parser.textOrNull())); + builder.queryId(parser.textOrNull()); break; case JOB_ID: builder.jobId(parser.textOrNull()); diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java index 2650ff3cb3..5781c3e44b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java @@ -38,6 +38,7 @@ public Optional getFlintIndexStateModel(String id, String public FlintIndexStateModel createFlintIndexStateModel( FlintIndexStateModel flintIndexStateModel) { return stateStore.create( + flintIndexStateModel.getId(), flintIndexStateModel, FlintIndexStateModel::copy, OpenSearchStateStoreUtil.getIndexName(flintIndexStateModel.getDatasourceName())); diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java index 314368771f..f5a1f70d1c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java @@ -21,6 +21,14 @@ public class OpenSearchIndexDMLResultStorageService implements IndexDMLResultSto public IndexDMLResult createIndexDMLResult(IndexDMLResult result) { DataSourceMetadata dataSourceMetadata = dataSourceService.getDataSourceMetadata(result.getDatasourceName()); - return stateStore.create(result, IndexDMLResult::copy, dataSourceMetadata.getResultIndex()); + return stateStore.create( + mapIdToDocumentId(result.getId()), + result, + IndexDMLResult::copy, + dataSourceMetadata.getResultIndex()); + } + + private String mapIdToDocumentId(String id) { + return "index" + id; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 615a914fee..ca252f48c6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -20,12 +20,14 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.asyncquery.OpenSearchAsyncQueryJobMetadataStorageService; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.DatasourceEmbeddedQueryIdProvider; import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; +import org.opensearch.sql.spark.dispatcher.QueryIdProvider; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.OpenSearchSessionStorageService; @@ -67,7 +69,7 @@ public AsyncQueryExecutorService asyncQueryExecutorService( @Provides public AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService( StateStore stateStore, AsyncQueryJobMetadataXContentSerializer serializer) { - return new OpensearchAsyncQueryJobMetadataStorageService(stateStore, serializer); + return new OpenSearchAsyncQueryJobMetadataStorageService(stateStore, serializer); } @Provides @@ -82,8 +84,15 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) { public SparkQueryDispatcher sparkQueryDispatcher( DataSourceService dataSourceService, SessionManager sessionManager, - QueryHandlerFactory queryHandlerFactory) { - return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); + QueryHandlerFactory queryHandlerFactory, + QueryIdProvider queryIdProvider) { + return new SparkQueryDispatcher( + dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + } + + @Provides + public QueryIdProvider queryIdProvider() { + return new DatasourceEmbeddedQueryIdProvider(); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 2b84f967f0..43dd4880e7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -11,7 +11,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -31,7 +30,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.RequestContext; import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; @@ -41,6 +39,7 @@ import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; @@ -55,7 +54,7 @@ public class AsyncQueryExecutorServiceImplTest { @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; @Mock private RequestContext requestContext; - private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); + private final String QUERY_ID = "QUERY_ID"; @BeforeEach void setUp() { @@ -89,7 +88,12 @@ void testCreateAsyncQuery() { TEST_CLUSTER_NAME, sparkSubmitParameterModifier); when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest)) - .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null)); + .thenReturn( + DispatchQueryResponse.builder() + .queryId(QUERY_ID) + .jobId(EMR_JOB_ID) + .jobType(JobType.INTERACTIVE) + .build()); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest, requestContext); @@ -99,7 +103,7 @@ void testCreateAsyncQuery() { verify(sparkExecutionEngineConfigSupplier, times(1)) .getSparkExecutionEngineConfig(requestContext); verify(sparkQueryDispatcher, times(1)).dispatch(expectedDispatchQueryRequest); - Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId()); + Assertions.assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId()); } @Test @@ -115,7 +119,12 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { modifier, TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null)); + .thenReturn( + DispatchQueryResponse.builder() + .queryId(QUERY_ID) + .jobId(EMR_JOB_ID) + .jobType(JobType.INTERACTIVE) + .build()); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index b15a911364..90a06edb19 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -58,6 +58,7 @@ import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.DatasourceEmbeddedQueryIdProvider; import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; @@ -235,7 +236,7 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( JobExecutionResponseReader jobExecutionResponseReader) { StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService( + new OpenSearchAsyncQueryJobMetadataStorageService( stateStore, new AsyncQueryJobMetadataXContentSerializer()); QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( @@ -262,7 +263,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( statementStorageService, emrServerlessClientFactory, pluginSettings), - queryHandlerFactory); + queryHandlerFactory, + new DatasourceEmbeddedQueryIdProvider()); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java similarity index 76% rename from spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java rename to spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java index 431f5b2b15..a0baaefab8 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java @@ -13,25 +13,24 @@ import org.junit.Test; import org.junit.jupiter.api.Assertions; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer; +import org.opensearch.sql.spark.utils.IDUtils; import org.opensearch.test.OpenSearchIntegTestCase; -public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest - extends OpenSearchIntegTestCase { +public class OpenSearchAsyncQueryJobMetadataStorageServiceTest extends OpenSearchIntegTestCase { public static final String DS_NAME = "mys3"; private static final String MOCK_SESSION_ID = "sessionId"; private static final String MOCK_RESULT_INDEX = "resultIndex"; private static final String MOCK_QUERY_ID = "00fdo6u94n7abo0q"; - private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; + private OpenSearchAsyncQueryJobMetadataStorageService openSearchJobMetadataStorageService; @Before public void setup() { - opensearchJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService( + openSearchJobMetadataStorageService = + new OpenSearchAsyncQueryJobMetadataStorageService( new StateStore(client(), clusterService()), new AsyncQueryJobMetadataXContentSerializer()); } @@ -40,15 +39,16 @@ public void setup() { public void testStoreJobMetadata() { AsyncQueryJobMetadata expected = AsyncQueryJobMetadata.builder() - .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME)) + .queryId(IDUtils.encode(DS_NAME)) .jobId(EMR_JOB_ID) .applicationId(EMRS_APPLICATION_ID) .resultIndex(MOCK_RESULT_INDEX) + .datasourceName(DS_NAME) .build(); - opensearchJobMetadataStorageService.storeJobMetadata(expected); + openSearchJobMetadataStorageService.storeJobMetadata(expected); Optional actual = - opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + openSearchJobMetadataStorageService.getJobMetadata(expected.getQueryId()); assertTrue(actual.isPresent()); assertEquals(expected, actual.get()); @@ -60,16 +60,17 @@ public void testStoreJobMetadata() { public void testStoreJobMetadataWithResultExtraData() { AsyncQueryJobMetadata expected = AsyncQueryJobMetadata.builder() - .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME)) + .queryId(IDUtils.encode(DS_NAME)) .jobId(EMR_JOB_ID) .applicationId(EMRS_APPLICATION_ID) .resultIndex(MOCK_RESULT_INDEX) .sessionId(MOCK_SESSION_ID) + .datasourceName(DS_NAME) .build(); - opensearchJobMetadataStorageService.storeJobMetadata(expected); + openSearchJobMetadataStorageService.storeJobMetadata(expected); Optional actual = - opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + openSearchJobMetadataStorageService.getJobMetadata(expected.getQueryId()); assertTrue(actual.isPresent()); assertEquals(expected, actual.get()); @@ -82,7 +83,7 @@ public void testGetJobMetadataWithMalformedQueryId() { AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( AsyncQueryNotFoundException.class, - () -> opensearchJobMetadataStorageService.getJobMetadata(MOCK_QUERY_ID)); + () -> openSearchJobMetadataStorageService.getJobMetadata(MOCK_QUERY_ID)); Assertions.assertEquals( String.format("Invalid QueryId: %s", MOCK_QUERY_ID), asyncQueryNotFoundException.getMessage()); @@ -93,7 +94,7 @@ public void testGetJobMetadataWithEmptyQueryId() { AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( AsyncQueryNotFoundException.class, - () -> opensearchJobMetadataStorageService.getJobMetadata("")); + () -> openSearchJobMetadataStorageService.getJobMetadata("")); Assertions.assertEquals("Invalid QueryId: ", asyncQueryNotFoundException.getMessage()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 7d43ccc7e3..2e536ef6b3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; @@ -43,12 +44,23 @@ @ExtendWith(MockitoExtension.class) class IndexDMLHandlerTest { + private static final String QUERY_ID = "QUERY_ID"; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private FlintIndexMetadataService flintIndexMetadataService; @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @InjectMocks IndexDMLHandler indexDMLHandler; + + private static final DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setName("mys3") + .setDescription("test description") + .setConnector(DataSourceType.S3GLUE) + .setDataSourceStatus(ACTIVE) + .build(); + @Test public void getResponseFromExecutor() { JSONObject result = new IndexDMLHandler(null, null, null, null).getResponseFromExecutor(null); @@ -59,28 +71,7 @@ public void getResponseFromExecutor() { @Test public void testWhenIndexDetailsAreNotFound() { - IndexDMLHandler indexDMLHandler = - new IndexDMLHandler( - jobExecutionResponseReader, - flintIndexMetadataService, - indexDMLResultStorageService, - flintIndexOpFactory); - DispatchQueryRequest dispatchQueryRequest = - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - "DROP INDEX", - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier); - DataSourceMetadata metadata = - new DataSourceMetadata.Builder() - .setName("mys3") - .setDescription("test description") - .setConnector(DataSourceType.S3GLUE) - .setDataSourceStatus(ACTIVE) - .build(); + DispatchQueryRequest dispatchQueryRequest = getDispatchQueryRequest("DROP INDEX"); IndexQueryDetails indexQueryDetails = IndexQueryDetails.builder() .mvName("mys3.default.http_logs_metrics") @@ -88,6 +79,7 @@ public void testWhenIndexDetailsAreNotFound() { .build(); DispatchQueryContext dispatchQueryContext = DispatchQueryContext.builder() + .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) .build(); @@ -103,28 +95,7 @@ public void testWhenIndexDetailsAreNotFound() { @Test public void testWhenIndexDetailsWithInvalidQueryActionType() { FlintIndexMetadata flintIndexMetadata = mock(FlintIndexMetadata.class); - IndexDMLHandler indexDMLHandler = - new IndexDMLHandler( - jobExecutionResponseReader, - flintIndexMetadataService, - indexDMLResultStorageService, - flintIndexOpFactory); - DispatchQueryRequest dispatchQueryRequest = - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - "CREATE INDEX", - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier); - DataSourceMetadata metadata = - new DataSourceMetadata.Builder() - .setName("mys3") - .setDescription("test description") - .setConnector(DataSourceType.S3GLUE) - .setDataSourceStatus(ACTIVE) - .build(); + DispatchQueryRequest dispatchQueryRequest = getDispatchQueryRequest("CREATE INDEX"); IndexQueryDetails indexQueryDetails = IndexQueryDetails.builder() .mvName("mys3.default.http_logs_metrics") @@ -133,6 +104,7 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { .build(); DispatchQueryContext dispatchQueryContext = DispatchQueryContext.builder() + .queryId(QUERY_ID) .dataSourceMetadata(metadata) .indexQueryDetails(indexQueryDetails) .build(); @@ -144,6 +116,17 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext); } + private DispatchQueryRequest getDispatchQueryRequest(String query) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + sparkSubmitParameterModifier); + } + @Test public void testStaticMethods() { Assertions.assertTrue(IndexDMLHandler.isIndexDMLQuery("dropIndexJobId")); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index a22ce7f460..5d04c86cce 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -18,7 +18,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -57,7 +56,6 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; @@ -92,6 +90,7 @@ public class SparkQueryDispatcherTest { @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock private QueryIdProvider queryIdProvider; @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -101,7 +100,7 @@ public class SparkQueryDispatcherTest { private SparkQueryDispatcher sparkQueryDispatcher; - private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); + private final String QUERY_ID = "QUERY_ID"; @Captor ArgumentCaptor startJobRequestArgumentCaptor; @@ -117,8 +116,8 @@ void setUp() { flintIndexOpFactory, emrServerlessClientFactory); sparkQueryDispatcher = - new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); - new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); + new SparkQueryDispatcher( + dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); } @Test @@ -834,7 +833,7 @@ void testCancelJob() { String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); - Assertions.assertEquals(QUERY_ID.getId(), queryId); + Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -897,7 +896,7 @@ void testCancelQueryWithNoSessionId() { String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); - Assertions.assertEquals(QUERY_ID.getId(), queryId); + Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -1224,7 +1223,7 @@ private AsyncQueryJobMetadata asyncQueryJobMetadata() { private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId( String statementId, String sessionId) { return AsyncQueryJobMetadata.builder() - .queryId(new AsyncQueryId(statementId)) + .queryId(statementId) .applicationId(EMRS_APPLICATION_ID) .jobId(EMR_JOB_ID) .sessionId(sessionId) diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index e3f610000c..357a09c3ee 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -371,7 +371,7 @@ public TestStatement run() { private QueryRequest queryRequest() { return new QueryRequest( - AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME), LangType.SQL, "select 1"); + AsyncQueryId.newAsyncQueryId(TEST_DATASOURCE_NAME).getId(), LangType.SQL, "select 1"); } private Statement createStatement(StatementId stId) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java index cf658ea017..f0cce5405c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java @@ -16,7 +16,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.model.JobType; @@ -29,7 +28,7 @@ class AsyncQueryJobMetadataXContentSerializerTest { void toXContentShouldSerializeAsyncQueryJobMetadata() throws Exception { AsyncQueryJobMetadata jobMetadata = AsyncQueryJobMetadata.builder() - .queryId(new AsyncQueryId("query1")) + .queryId("query1") .applicationId("app1") .jobId("job1") .resultIndex("result1") @@ -72,7 +71,7 @@ void fromXContentShouldDeserializeAsyncQueryJobMetadata() throws Exception { AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L); - assertEquals("query1", jobMetadata.getQueryId().getId()); + assertEquals("query1", jobMetadata.getQueryId()); assertEquals("job1", jobMetadata.getJobId()); assertEquals("app1", jobMetadata.getApplicationId()); assertEquals("result1", jobMetadata.getResultIndex()); @@ -142,7 +141,7 @@ void fromXContentShouldDeserializeAsyncQueryWithJobTypeNUll() throws Exception { AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L); - assertEquals("query1", jobMetadata.getQueryId().getId()); + assertEquals("query1", jobMetadata.getQueryId()); assertEquals("job1", jobMetadata.getJobId()); assertEquals("app1", jobMetadata.getApplicationId()); assertEquals("result1", jobMetadata.getResultIndex()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java index c9ee5e5ce8..977f77b397 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java @@ -58,7 +58,8 @@ void getFlintIndexStateModel() { @Test void createFlintIndexStateModel() { - when(mockStateStore.create(any(), any(), any())).thenReturn(responseFlintIndexStateModel); + when(mockStateStore.create(any(), any(), any(), any())) + .thenReturn(responseFlintIndexStateModel); when(flintIndexStateModel.getDatasourceName()).thenReturn(DATASOURCE); FlintIndexStateModel result =