Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SparkQueryDispatcher #2636

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.dispatcher;

import lombok.RequiredArgsConstructor;
import org.opensearch.client.Client;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;

/**
 * Central factory for the {@code AsyncQueryHandler} implementations used by the
 * dispatcher. Holds the shared collaborators once and wires each handler on demand.
 */
@RequiredArgsConstructor
public class QueryHandlerFactory {

  private final JobExecutionResponseReader jobExecutionResponseReader;
  private final FlintIndexMetadataService flintIndexMetadataService;
  private final Client client;
  private final SessionManager sessionManager;
  private final LeaseManager leaseManager;
  private final StateStore stateStore;

  /** Handler for interactive, session-based queries; needs no EMR-S client. */
  public InteractiveQueryHandler getInteractiveQueryHandler() {
    return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
  }

  /** Handler for one-shot batch queries. */
  public BatchQueryHandler getBatchQueryHandler(EMRServerlessClient emrServerlessClient) {
    return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
  }

  /** Handler for streaming queries. */
  public StreamingQueryHandler getStreamingQueryHandler(EMRServerlessClient emrServerlessClient) {
    return new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
  }

  /** Handler for index refresh queries. */
  public RefreshQueryHandler getRefreshQueryHandler(EMRServerlessClient emrServerlessClient) {
    return new RefreshQueryHandler(
        emrServerlessClient,
        jobExecutionResponseReader,
        flintIndexMetadataService,
        stateStore,
        leaseManager);
  }

  /** Handler for index DML queries. */
  public IndexDMLHandler getIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
    return new IndexDMLHandler(
        emrServerlessClient,
        jobExecutionResponseReader,
        flintIndexMetadataService,
        stateStore,
        client);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.jetbrains.annotations.NotNull;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
Expand All @@ -23,10 +23,6 @@
import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
import org.opensearch.sql.spark.dispatcher.model.JobType;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.utils.SQLQueryUtils;

Expand All @@ -39,65 +35,68 @@ public class SparkQueryDispatcher {
public static final String CLUSTER_NAME_TAG_KEY = "domain_ident";
public static final String JOB_TYPE_TAG_KEY = "type";

private EMRServerlessClientFactory emrServerlessClientFactory;

private DataSourceService dataSourceService;

private JobExecutionResponseReader jobExecutionResponseReader;

private FlintIndexMetadataService flintIndexMetadataService;

private Client client;

private SessionManager sessionManager;

private LeaseManager leaseManager;

private StateStore stateStore;
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final DataSourceService dataSourceService;
private final SessionManager sessionManager;
private final QueryHandlerFactory queryHandlerFactory;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
AsyncQueryHandler asyncQueryHandler =
sessionManager.isEnabled()
? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
: new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
DispatchQueryContext.DispatchQueryContextBuilder contextBuilder =
DispatchQueryContext.builder()

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails, emrServerlessClient)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.build();
return getDefaultAsyncQueryHandler(emrServerlessClient).submit(dispatchQueryRequest, context);
}
}

/**
 * Builds a {@code DispatchQueryContext} builder pre-populated with the datasource
 * metadata, the default job-submission tags, and a freshly generated query id.
 */
private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
    DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
  DispatchQueryContext.DispatchQueryContextBuilder builder = DispatchQueryContext.builder();
  builder.dataSourceMetadata(dataSourceMetadata);
  builder.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest));
  builder.queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()));
  return builder;
}

// override asyncQueryHandler with specific.
if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails =
SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
fillMissingDetails(dispatchQueryRequest, indexQueryDetails);
contextBuilder.indexQueryDetails(indexQueryDetails);

if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
asyncQueryHandler = createIndexDMLHandler(emrServerlessClient);
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
asyncQueryHandler =
new StreamingQueryHandler(
emrServerlessClient, jobExecutionResponseReader, leaseManager);
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// manual refresh should be handled by batch handler
asyncQueryHandler =
new RefreshQueryHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
}
private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(IndexQueryDetails indexQueryDetails, EMRServerlessClient emrServerlessClient) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient);
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient);
vmmusings marked this conversation as resolved.
Show resolved Hide resolved
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// manual refresh should be handled by batch handler
return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient);
} else {
return getDefaultAsyncQueryHandler(emrServerlessClient);
}
return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build());
}

/** Returns the interactive handler when sessions are enabled, the batch handler otherwise. */
@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler(EMRServerlessClient emrServerlessClient) {
  if (sessionManager.isEnabled()) {
    return queryHandlerFactory.getInteractiveQueryHandler();
  }
  return queryHandlerFactory.getBatchQueryHandler(emrServerlessClient);
}

/**
 * Extracts the index details from the request's query text and backfills the
 * datasource name when the query did not specify one.
 */
@NotNull
private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) {
  IndexQueryDetails details = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
  fillDatasourceName(dispatchQueryRequest, details);
  return details;
}


private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) {
Boolean isCreateAutoRefreshIndex =
IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())
Expand All @@ -119,58 +118,33 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail
}

/**
 * Fetches the result of a previously submitted async query via its owning handler.
 *
 * <p>NOTE(review): this span contained interleaved removed-side diff lines; the
 * added-side one-line delegation (visible in the diff) is reconstructed here.
 */
public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
  return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
      .getQueryResponse(asyncQueryJobMetadata);
}

/** Cancels a previously submitted async query via its owning handler. */
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
  AsyncQueryHandler handler = getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata);
  return handler.cancelJob(asyncQueryJobMetadata);
}

/**
 * Resolves the handler that owns an already-submitted query, based on its stored
 * metadata: interactive when a session id is present, index DML for DML job ids,
 * then refresh/streaming/batch by job type.
 *
 * <p>NOTE(review): this span contained interleaved removed-side diff lines; the
 * added-side implementation is reconstructed here.
 */
private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
    AsyncQueryJobMetadata asyncQueryJobMetadata) {
  EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
  if (asyncQueryJobMetadata.getSessionId() != null) {
    return queryHandlerFactory.getInteractiveQueryHandler();
  } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
    return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient);
  } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
    return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient);
  } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
    return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient);
  } else {
    return queryHandlerFactory.getBatchQueryHandler(emrServerlessClient);
  }
}

// Builds an IndexDMLHandler from this class's own collaborators.
// NOTE(review): appears superseded by QueryHandlerFactory#getIndexDMLHandler in this
// refactor — confirm it is unused and remove.
private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
  return new IndexDMLHandler(
      emrServerlessClient,
      jobExecutionResponseReader,
      flintIndexMetadataService,
      stateStore,
      client);
}

// TODO: Revisit this logic.
// Currently, if a datasource is not provided in the query,
// Spark assumes the datasource to be the catalog.
// This is required to handle the drop index case properly when the datasource name is not provided.
private static void fillMissingDetails(
private static void fillDatasourceName(
DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (indexQueryDetails.getFullyQualifiedTableName() != null
&& indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
Expand All @@ -36,7 +37,8 @@
public class AsyncExecutorServiceModule extends AbstractModule {

// No explicit bindings; everything is supplied via @Provides methods below.
// NOTE(review): the diff scrape duplicated this declaration (old and new side);
// a single definition is kept here.
@Override
protected void configure() {}

@Provides
public AsyncQueryExecutorService asyncQueryExecutorService(
Expand Down Expand Up @@ -67,15 +69,27 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) {
/** DI provider wiring the {@code SparkQueryDispatcher} with its collaborators. */
public SparkQueryDispatcher sparkQueryDispatcher(
    EMRServerlessClientFactory emrServerlessClientFactory,
    DataSourceService dataSourceService,
    SessionManager sessionManager,
    QueryHandlerFactory queryHandlerFactory) {
  return new SparkQueryDispatcher(
      emrServerlessClientFactory, dataSourceService, sessionManager, queryHandlerFactory);
}

@Provides
public QueryHandlerFactory queryhandlerFactory(
JobExecutionResponseReader jobExecutionResponseReader,
FlintIndexMetadataServiceImpl flintIndexMetadataReader,
NodeClient client,
SessionManager sessionManager,
DefaultLeaseManager defaultLeaseManager,
StateStore stateStore) {
return new SparkQueryDispatcher(
emrServerlessClientFactory,
dataSourceService,
StateStore stateStore
) {
return new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataReader,
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.session.SessionModel;
Expand Down Expand Up @@ -200,16 +201,20 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
StateStore stateStore = new StateStore(client, clusterService);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory(
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
client,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore
);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClientFactory,
this.dataSourceService,
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
client,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
queryHandlerFactory);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,20 @@ public class SparkQueryDispatcherTest {

@BeforeEach
void setUp() {
QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataService,
openSearchClient,
sessionManager,
leaseManager,
stateStore
);
sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClientFactory,
dataSourceService,
jobExecutionResponseReader,
flintIndexMetadataService,
openSearchClient,
sessionManager,
leaseManager,
stateStore);
queryHandlerFactory);
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
}

Expand Down
Loading