Skip to content

Commit

Permalink
Refactor async executor service dependencies using guice framework
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <[email protected]>
  • Loading branch information
vmmusings committed Feb 1, 2024
1 parent e59bf75 commit 6cee349
Show file tree
Hide file tree
Showing 19 changed files with 547 additions and 209 deletions.
116 changes: 21 additions & 95 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,14 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static java.util.Collections.singletonList;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
Expand Down Expand Up @@ -68,7 +61,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -87,26 +79,13 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction;
import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;
Expand All @@ -127,7 +106,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

public String name() {
Expand Down Expand Up @@ -223,32 +201,16 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
b -> {
b.bind(NodeClient.class).toInstance((NodeClient) client);
b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings);
b.bind(DataSourceService.class).toInstance(dataSourceService);
b.bind(ClusterService.class).toInstance(clusterService);
});

modules.add(new AsyncExecutorServiceModule());
injector = modules.createInjector();
ClusterManagerEventListener clusterManagerEventListener =
new ClusterManagerEventListener(
Expand All @@ -261,12 +223,15 @@ public Collection<Object> createComponents(
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
dataSourceService,
injector.getInstance(AsyncQueryExecutorService.class),
clusterManagerEventListener,
pluginSettings);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return Collections.singletonList(
return singletonList(
new FixedExecutorBuilder(
settings,
AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME,
Expand Down Expand Up @@ -319,56 +284,17 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceUserAuthorizationHelper);
}

private AsyncQueryExecutorService createAsyncQueryExecutorService(
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
StateStore stateStore = new StateStore(client, clusterService);
registerStateStoreMetrics(stateStore);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
EMRServerlessClient emrServerlessClient =
createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClient,
this.dataSourceService,
new DataSourceUserAuthorizationHelperImpl(client),
jobExecutionResponseReader,
new FlintIndexMetadataReaderImpl(client),
client,
new SessionManager(stateStore, emrServerlessClient, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
sparkExecutionEngineConfigSupplier);
}

private void registerStateStoreMetrics(StateStore stateStore) {
GaugeMetric<Long> activeSessionMetric =
new GaugeMetric<>(
"active_async_query_sessions_count",
StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
GaugeMetric<Long> activeStatementMetric =
new GaugeMetric<>(
"active_async_query_statements_count",
StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));
Metrics.getInstance().registerMetric(activeSessionMetric);
Metrics.getInstance().registerMetric(activeStatementMetric);
}

private EMRServerlessClient createEMRServerlessClient(String region) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(region)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
private AsyncQueryExecutorService createAsyncExecutorService() {
ModulesBuilder modulesBuilder = new ModulesBuilder();
modulesBuilder.add(
b -> {
b.bind(NodeClient.class).toInstance(client);
b.bind(org.opensearch.sql.common.setting.Settings.class)
.toInstance(new OpenSearchSettings(clusterService.getClusterSettings()));
b.bind(DataSourceService.class).toInstance(dataSourceService);
});
modulesBuilder.add(new AsyncExecutorServiceModule());
this.injector = modulesBuilder.createInjector();
return this.injector.getInstance(AsyncQueryExecutorService.class);
}
}
3 changes: 2 additions & 1 deletion spark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ jacocoTestCoverageVerification {
// TODO: add tests for purging flint indices
'org.opensearch.sql.spark.cluster.ClusterManagerEventListener*',
'org.opensearch.sql.spark.cluster.FlintIndexRetention',
'org.opensearch.sql.spark.cluster.IndexCleanup'
'org.opensearch.sql.spark.cluster.IndexCleanup',
'org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule'
]
limit {
counter = 'LINE'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.sql.spark.asyncquery;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;

Expand Down Expand Up @@ -34,26 +33,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService
private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
private SparkQueryDispatcher sparkQueryDispatcher;
private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
private Boolean isSparkJobExecutionEnabled;

public AsyncQueryExecutorServiceImpl() {
this.isSparkJobExecutionEnabled = Boolean.FALSE;
}

public AsyncQueryExecutorServiceImpl(
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService,
SparkQueryDispatcher sparkQueryDispatcher,
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) {
this.isSparkJobExecutionEnabled = Boolean.TRUE;
this.asyncQueryJobMetadataStorageService = asyncQueryJobMetadataStorageService;
this.sparkQueryDispatcher = sparkQueryDispatcher;
this.sparkExecutionEngineConfigSupplier = sparkExecutionEngineConfigSupplier;
}

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest) {
validateSparkExecutionEngineSettings();
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -80,7 +63,6 @@ public CreateAsyncQueryResponse createAsyncQuery(

@Override
public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
validateSparkExecutionEngineSettings();
Optional<AsyncQueryJobMetadata> jobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (jobMetadata.isPresent()) {
Expand Down Expand Up @@ -120,14 +102,4 @@ public String cancelQuery(String queryId) {
}
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}

private void validateSparkExecutionEngineSettings() {
if (!isSparkJobExecutionEnabled) {
throw new IllegalArgumentException(
String.format(
"Async Query APIs are disabled as %s is not configured in cluster settings. Please"
+ " configure the setting and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.opensearch.sql.spark.client;

/** Factory interface for creating instances of {@link EMRServerlessClient}. */
public interface EMRServerlessClientFactory {

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
EMRServerlessClient getClient();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.sql.spark.client;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

/** Implementation of {@link EMRServerlessClientFactory}. */
@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

private final SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
private EMRServerlessClient emrServerlessClient;
private String region;

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
@Override
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
this.emrServerlessClient = createEMRServerlessClient(this.region);
}
return this.emrServerlessClient;
}

private boolean isNewClientCreationRequired(String region) {
return !region.equals(this.region);
}

private void validateSparkExecutionEngineConfig(
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
if (sparkExecutionEngineConfig == null || sparkExecutionEngineConfig.getRegion() == null) {
throw new IllegalArgumentException(
String.format(
"Async Query APIs are disabled as %s is not configured in cluster settings. Please"
+ " configure the setting and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
}
}

private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(awsRegion)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
}
}
Loading

0 comments on commit 6cee349

Please sign in to comment.