Skip to content

Commit

Permalink
Refactor SparkQueryDispatcher (#2636)
Browse files Browse the repository at this point in the history
* Refactor SparkQueryDispatcher

Signed-off-by: Tomoyuki Morita <[email protected]>

* Remove EMRServerlessClientFactory from SparkQueryDispatcher

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix unit test failures in SparkQueryDispatcherTest

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
(cherry picked from commit d32cf94)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Apr 30, 2024
1 parent 294566f commit f40e94a
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 112 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.dispatcher;

import lombok.RequiredArgsConstructor;
import org.opensearch.client.Client;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;

/**
 * Factory that centralizes construction of every {@link AsyncQueryHandler} variant so
 * callers (e.g. the query dispatcher) need not know each handler's wiring.
 *
 * <p>All collaborators are injected once via the Lombok-generated constructor
 * ({@code @RequiredArgsConstructor}); field declaration order below defines the
 * constructor's parameter order and must not change.
 */
@RequiredArgsConstructor
public class QueryHandlerFactory {

  private final JobExecutionResponseReader jobExecutionResponseReader;
  private final FlintIndexMetadataService flintIndexMetadataService;
  private final Client client;
  private final SessionManager sessionManager;
  private final LeaseManager leaseManager;
  private final StateStore stateStore;
  private final EMRServerlessClientFactory emrServerlessClientFactory;

  /** Handler for interactive, session-based queries; needs no EMR Serverless client. */
  public InteractiveQueryHandler getInteractiveQueryHandler() {
    return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
  }

  /** Handler for one-shot batch queries submitted to EMR Serverless. */
  public BatchQueryHandler getBatchQueryHandler() {
    return new BatchQueryHandler(
        emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager);
  }

  /** Handler for streaming (auto-refresh index) queries submitted to EMR Serverless. */
  public StreamingQueryHandler getStreamingQueryHandler() {
    return new StreamingQueryHandler(
        emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager);
  }

  /** Handler for manual index refresh queries; tracks Flint index metadata and state. */
  public RefreshQueryHandler getRefreshQueryHandler() {
    return new RefreshQueryHandler(
        emrServerlessClientFactory.getClient(),
        jobExecutionResponseReader,
        flintIndexMetadataService,
        stateStore,
        leaseManager);
  }

  /** Handler for index DML statements (e.g. DROP/VACUUM of Flint indices). */
  public IndexDMLHandler getIndexDMLHandler() {
    return new IndexDMLHandler(
        emrServerlessClientFactory.getClient(),
        jobExecutionResponseReader,
        flintIndexMetadataService,
        stateStore,
        client);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,19 @@
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.jetbrains.annotations.NotNull;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
import org.opensearch.sql.spark.dispatcher.model.JobType;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.utils.SQLQueryUtils;

Expand All @@ -39,63 +33,67 @@ public class SparkQueryDispatcher {
public static final String CLUSTER_NAME_TAG_KEY = "domain_ident";
public static final String JOB_TYPE_TAG_KEY = "type";

private EMRServerlessClientFactory emrServerlessClientFactory;

private DataSourceService dataSourceService;

private JobExecutionResponseReader jobExecutionResponseReader;

private FlintIndexMetadataService flintIndexMetadataService;

private Client client;

private SessionManager sessionManager;

private LeaseManager leaseManager;

private StateStore stateStore;
private final DataSourceService dataSourceService;
private final SessionManager sessionManager;
private final QueryHandlerFactory queryHandlerFactory;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
AsyncQueryHandler asyncQueryHandler =
sessionManager.isEnabled()
? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
: new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
DispatchQueryContext.DispatchQueryContextBuilder contextBuilder =
DispatchQueryContext.builder()
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()));

// override asyncQueryHandler with specific.

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails =
SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
fillMissingDetails(dispatchQueryRequest, indexQueryDetails);
contextBuilder.indexQueryDetails(indexQueryDetails);

if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
asyncQueryHandler = createIndexDMLHandler(emrServerlessClient);
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
asyncQueryHandler =
new StreamingQueryHandler(
emrServerlessClient, jobExecutionResponseReader, leaseManager);
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// manual refresh should be handled by batch handler
asyncQueryHandler =
new RefreshQueryHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
}
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
}
return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build());
}

/**
 * Builds a dispatch context pre-populated with the pieces every query needs: the resolved
 * datasource metadata, the default EMR job tags, and a freshly generated async query id.
 */
private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
    DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
  Map<String, String> defaultTags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
  AsyncQueryId newQueryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName());
  return DispatchQueryContext.builder()
      .dataSourceMetadata(dataSourceMetadata)
      .tags(defaultTags)
      .queryId(newQueryId);
}

/**
 * Selects the handler for a Flint extension (index) query. The order of the checks is
 * significant: index DML takes precedence over streaming, and streaming over manual
 * refresh; any other index query falls back to the default handler.
 */
private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
    IndexQueryDetails indexQueryDetails) {
  if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
    return queryHandlerFactory.getIndexDMLHandler();
  }
  if (isEligibleForStreamingQuery(indexQueryDetails)) {
    return queryHandlerFactory.getStreamingQueryHandler();
  }
  if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
    // Manual refresh is handled like a batch job by the dedicated refresh handler.
    return queryHandlerFactory.getRefreshQueryHandler();
  }
  return getDefaultAsyncQueryHandler();
}

/**
 * Default handler choice: interactive (session-based) when sessions are enabled,
 * otherwise a plain batch handler.
 */
@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler() {
  if (sessionManager.isEnabled()) {
    return queryHandlerFactory.getInteractiveQueryHandler();
  }
  return queryHandlerFactory.getBatchQueryHandler();
}

/**
 * Parses the request's query string into {@link IndexQueryDetails} and backfills the
 * datasource name from the request when the query did not specify one.
 */
@NotNull
private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) {
  IndexQueryDetails details = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
  fillDatasourceName(dispatchQueryRequest, details);
  return details;
}

private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) {
Expand All @@ -119,58 +117,35 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail
}

public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
if (asyncQueryJobMetadata.getSessionId() != null) {
return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
.getQueryResponse(asyncQueryJobMetadata);
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata);
} else {
return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager)
.getQueryResponse(asyncQueryJobMetadata);
}
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.getQueryResponse(asyncQueryJobMetadata);
}

public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
AsyncQueryHandler queryHandler;
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.cancelJob(asyncQueryJobMetadata);
}

private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
AsyncQueryJobMetadata asyncQueryJobMetadata) {
if (asyncQueryJobMetadata.getSessionId() != null) {
queryHandler =
new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
return queryHandlerFactory.getInteractiveQueryHandler();
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
queryHandler = createIndexDMLHandler(emrServerlessClient);
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
queryHandler =
new RefreshQueryHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
return queryHandlerFactory.getRefreshQueryHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
queryHandler =
new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
return queryHandlerFactory.getStreamingQueryHandler();
} else {
queryHandler =
new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
return queryHandlerFactory.getBatchQueryHandler();
}
return queryHandler.cancelJob(asyncQueryJobMetadata);
}

private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
return new IndexDMLHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
client);
}

// TODO: Revisit this logic.
// Currently, if the datasource is not provided in the query,
// Spark assumes the datasource to be the catalog.
// This is required to handle drop index case properly when datasource name is not provided.
private static void fillMissingDetails(
private static void fillDatasourceName(
DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (indexQueryDetails.getFullyQualifiedTableName() != null
&& indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
Expand Down Expand Up @@ -65,23 +66,29 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) {

@Provides
public SparkQueryDispatcher sparkQueryDispatcher(
EMRServerlessClientFactory emrServerlessClientFactory,
DataSourceService dataSourceService,
SessionManager sessionManager,
QueryHandlerFactory queryHandlerFactory) {
return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory);
}

@Provides
public QueryHandlerFactory queryhandlerFactory(
JobExecutionResponseReader jobExecutionResponseReader,
FlintIndexMetadataServiceImpl flintIndexMetadataReader,
NodeClient client,
SessionManager sessionManager,
DefaultLeaseManager defaultLeaseManager,
StateStore stateStore) {
return new SparkQueryDispatcher(
emrServerlessClientFactory,
dataSourceService,
StateStore stateStore,
EMRServerlessClientFactory emrServerlessClientFactory) {
return new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataReader,
client,
sessionManager,
defaultLeaseManager,
stateStore);
stateStore,
emrServerlessClientFactory);
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.session.SessionModel;
Expand Down Expand Up @@ -200,16 +201,20 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
StateStore stateStore = new StateStore(client, clusterService);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClientFactory,
this.dataSourceService,
QueryHandlerFactory queryHandlerFactory =
new QueryHandlerFactory(
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
client,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
stateStore,
emrServerlessClientFactory);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
this.dataSourceService,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
queryHandlerFactory);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
Expand Down
Loading

0 comments on commit f40e94a

Please sign in to comment.