From cbb0c141b6bc0486c2a7bcc048b66fd3dea15629 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 8 Sep 2023 17:03:37 -0700 Subject: [PATCH] Create Job API Signed-off-by: Vamsi Manohar --- common/build.gradle | 4 +- .../sql/common/setting/Settings.java | 2 +- .../sql/datasource/DataSourceService.java | 9 + .../sql/executor/ExecutionEngine.java | 1 + .../service/DataSourceServiceImpl.java | 11 ++ integ-test/build.gradle | 1 + .../setting/OpenSearchSettings.java | 13 ++ .../org/opensearch/sql/plugin/SQLPlugin.java | 49 +++++- .../plugin-metadata/plugin-security.policy | 9 + .../sql/protocol/response/QueryResult.java | 3 + spark/build.gradle | 5 +- .../sql/spark/client/EmrClientImpl.java | 15 +- .../sql/spark/client/EmrServerlessClient.java | 17 ++ .../spark/client/EmrServerlessClientImpl.java | 126 ++++++++++++++ .../config/SparkExecutionEngineConfig.java | 21 +++ .../spark/data/constants/SparkConstants.java | 8 +- .../dispatcher/SparkQueryDispatcher.java | 76 ++++++++ ...DefaultSparkSqlFunctionResponseHandle.java | 3 +- .../sql/spark/jobs/JobExecutorService.java | 21 +++ .../spark/jobs/JobExecutorServiceImpl.java | 85 +++++++++ .../spark/jobs/JobMetadataStorageService.java | 11 ++ .../OpensearchJobMetadataStorageService.java | 162 ++++++++++++++++++ .../sql/spark/jobs/model/JobMetadata.java | 87 ++++++++++ ...Response.java => SparkResponseReader.java} | 17 +- .../spark/rest/RestJobManagementAction.java | 2 +- .../spark/rest/model/CreateJobResponse.java | 15 ++ .../spark/storage/SparkStorageFactory.java | 4 +- .../TransportCreateJobRequestAction.java | 28 ++- .../TransportGetQueryResultRequestAction.java | 38 +++- .../resources/job-metadata-index-mapping.yml | 20 +++ .../resources/job-metadata-index-settings.yml | 11 ++ .../sql/spark/client/EmrClientImplTest.java | 16 +- .../client/EmrServerlessClientImplTest.java | 69 ++++++++ .../sql/spark/constants/TestConstants.java | 5 + ...Test.java => SparkResponseReaderTest.java} | 33 ++-- 
.../TransportCreateJobRequestActionTest.java | 5 +- ...nsportGetQueryResultRequestActionTest.java | 4 +- 37 files changed, 952 insertions(+), 54 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java rename spark/src/main/java/org/opensearch/sql/spark/response/{SparkResponse.java => SparkResponseReader.java} (88%) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java create mode 100644 spark/src/main/resources/job-metadata-index-mapping.yml create mode 100644 spark/src/main/resources/job-metadata-index-settings.yml create mode 100644 spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java rename spark/src/test/java/org/opensearch/sql/spark/response/{SparkResponseTest.java => SparkResponseReaderTest.java} (77%) diff --git a/common/build.gradle b/common/build.gradle index 5cf219fbae..109cad59cb 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -39,8 +39,8 @@ dependencies { api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' implementation 
'com.github.babbel:okhttp-aws-signer:1.0.2' - api group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.1' - api group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.1' + api group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.545' + api group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.545' implementation "com.github.seancfoley:ipaddress:5.4.0" testImplementation group: 'junit', name: 'junit', version: '4.13.2' diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index be780e8d80..8daf0e9bf6 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -35,7 +35,7 @@ public enum Key { METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"), METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"), - + SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"), CLUSTER_NAME("cluster.name"); @Getter private final String keyValue; diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 3d6ddc864e..54daaffaef 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -39,6 +39,15 @@ public interface DataSourceService { */ DataSourceMetadata getDataSourceMetadata(String name); + /** + * Returns dataSourceMetadata object with specific name. The returned object contains all the + * metadata information. + * + * @param name name of the {@link DataSource}. + * @return the {@link DataSourceMetadata}. + */ + DataSourceMetadata getRawDataSourceMetadata(String name); + /** * Register {@link DataSource} defined by {@link DataSourceMetadata}. 
* diff --git a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java index 43b8ccb62e..ffcddfcafd 100644 --- a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java +++ b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java @@ -45,6 +45,7 @@ void execute( /** Data class that encapsulates ExprValue. */ @Data class QueryResponse { + private String status = "COMPLETED"; private final Schema schema; private final List results; private final Cursor cursor; diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 2ac480bbf2..302f49409e 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -75,6 +75,17 @@ public DataSourceMetadata getDataSourceMetadata(String datasourceName) { return dataSourceMetadataOptional.get(); } + @Override + public DataSourceMetadata getRawDataSourceMetadata(String datasourceName) { + Optional dataSourceMetadataOptional = + getDataSourceMetadataFromName(datasourceName); + if (dataSourceMetadataOptional.isEmpty()) { + throw new IllegalArgumentException( + "DataSource with name: " + datasourceName + " doesn't exist."); + } + return dataSourceMetadataOptional.get(); + } + @Override public DataSource getDataSource(String dataSourceName) { Optional dataSourceMetadataOptional = diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 0404900450..dc92f9ebb3 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -162,6 +162,7 @@ configurations.all { resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk7:1.5.31" resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force 
"org.slf4j:slf4j-api:1.7.36" + resolutionStrategy.force "com.amazonaws:aws-java-sdk-core:1.12.545" } configurations { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 48ceacaf10..76bda07607 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -129,6 +129,12 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting SPARK_EXECUTION_ENGINE_CONFIG = + Setting.simpleString( + Key.SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue(), + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -193,6 +199,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.DATASOURCES_URI_HOSTS_DENY_LIST, DATASOURCE_URI_HOSTS_DENY_LIST, new Updater(Key.DATASOURCES_URI_HOSTS_DENY_LIST)); + register( + settingBuilder, + clusterSettings, + Key.SPARK_EXECUTION_ENGINE_CONFIG, + SPARK_EXECUTION_ENGINE_CONFIG, + new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -257,6 +269,7 @@ public static List> pluginSettings() { .add(METRICS_ROLLING_WINDOW_SETTING) .add(METRICS_ROLLING_INTERVAL_SETTING) .add(DATASOURCE_URI_HOSTS_DENY_LIST) + .add(SPARK_EXECUTION_ENGINE_CONFIG) .build(); } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 80e1a6b1a3..5d7aa83911 100644 --- 
a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -6,9 +6,15 @@ package org.opensearch.sql.plugin; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -83,6 +89,15 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; +import org.opensearch.sql.spark.client.EmrServerlessClient; +import org.opensearch.sql.spark.client.EmrServerlessClientImpl; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.jobs.JobExecutorService; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; +import org.opensearch.sql.spark.jobs.JobMetadataStorageService; +import org.opensearch.sql.spark.jobs.OpensearchJobMetadataStorageService; +import org.opensearch.sql.spark.response.SparkResponseReader; import org.opensearch.sql.spark.rest.RestJobManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCreateJobRequestAction; @@ -110,6 +125,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { private NodeClient client; private DataSourceServiceImpl 
dataSourceService; + private JobExecutorService jobExecutorService; private Injector injector; public String name() { @@ -202,6 +218,7 @@ public Collection createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); + this.jobExecutorService = createJobManagementService(); ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); @@ -213,7 +230,7 @@ public Collection createComponents( }); injector = modules.createInjector(); - return ImmutableList.of(dataSourceService); + return ImmutableList.of(dataSourceService, jobExecutorService); } @Override @@ -270,4 +287,34 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); } + + private JobExecutorService createJobManagementService() { + JobMetadataStorageService jobMetadataStorageService = + new OpensearchJobMetadataStorageService(client, clusterService); + EmrServerlessClient emrServerlessClient = createEMRServerlessClient(); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(emrServerlessClient, this.dataSourceService); + return new JobExecutorServiceImpl( + jobMetadataStorageService, sparkQueryDispatcher, pluginSettings); + } + + private EmrServerlessClient createEMRServerlessClient() { + String sparkExecutionEngineConfigString = + this.pluginSettings.getSettingValue( + org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + SparkExecutionEngineConfig sparkExecutionEngineConfig = + SparkExecutionEngineConfig.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigString); + AWSEMRServerless awsemrServerless = + AWSEMRServerlessClientBuilder.standard() + .withRegion(sparkExecutionEngineConfig.getRegion()) + .withCredentials(new 
DefaultAWSCredentialsProviderChain()) + .build(); + return new EmrServerlessClientImpl( + awsemrServerless, new SparkResponseReader(client, null, STEP_ID_FIELD)); + }); + } } diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index aec517aa84..fcf70c01f9 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -15,4 +15,13 @@ grant { // ml-commons client permission java.lang.RuntimePermission "setContextClassLoader"; + + // aws credentials + permission java.io.FilePermission "${user.home}${/}.aws${/}*", "read"; + + // Permissions for aws emr serverless sdk + permission javax.management.MBeanServerPermission "createMBeanServer"; + permission javax.management.MBeanServerPermission "findMBeanServer"; + permission javax.management.MBeanPermission "com.amazonaws.metrics.*", "*"; + permission javax.management.MBeanTrustPermission "register"; }; diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java index 03be0875cf..9d6deb84c5 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java @@ -12,6 +12,7 @@ import java.util.Map; import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; @@ -25,6 +26,8 @@ @RequiredArgsConstructor public class QueryResult implements Iterable { + @Setter @Getter private String status; + @Getter private final ExecutionEngine.Schema schema; /** Results which are collection of expression. 
*/ diff --git a/spark/build.gradle b/spark/build.gradle index b93e3327ce..cbbb6caf4b 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -15,11 +15,14 @@ repositories { dependencies { api project(':core') + implementation project(':protocol') implementation project(':datasources') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.json', name: 'json', version: '20230227' - implementation group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.1' + api group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.545' + api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' + implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java index 1a3304994b..b4dfd3812a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java @@ -23,7 +23,7 @@ import org.apache.logging.log4j.Logger; import org.json.JSONObject; import org.opensearch.sql.spark.helper.FlintHelper; -import org.opensearch.sql.spark.response.SparkResponse; +import org.opensearch.sql.spark.response.SparkResponseReader; public class EmrClientImpl implements SparkClient { private final AmazonElasticMapReduce emr; @@ -31,25 +31,26 @@ public class EmrClientImpl implements SparkClient { private final FlintHelper flint; private final String sparkApplicationJar; private static final Logger logger = LogManager.getLogger(EmrClientImpl.class); - private SparkResponse sparkResponse; + private SparkResponseReader sparkResponseReader; /** * Constructor for EMR Client Implementation. 
* * @param emr EMR helper * @param flint Opensearch args for flint integration jar - * @param sparkResponse Response object to help with retrieving results from Opensearch index + * @param sparkResponseReader Response object to help with retrieving results from Opensearch + * index */ public EmrClientImpl( AmazonElasticMapReduce emr, String emrCluster, FlintHelper flint, - SparkResponse sparkResponse, + SparkResponseReader sparkResponseReader, String sparkApplicationJar) { this.emr = emr; this.emrCluster = emrCluster; this.flint = flint; - this.sparkResponse = sparkResponse; + this.sparkResponseReader = sparkResponseReader; this.sparkApplicationJar = sparkApplicationJar == null ? SPARK_SQL_APPLICATION_JAR : sparkApplicationJar; } @@ -57,7 +58,7 @@ public EmrClientImpl( @Override public JSONObject sql(String query) throws IOException { runEmrApplication(query); - return sparkResponse.getResultFromOpensearchIndex(); + return sparkResponseReader.getResultFromOpensearchIndex(); } @VisibleForTesting @@ -98,7 +99,7 @@ void runEmrApplication(String query) { new DescribeStepRequest().withClusterId(emrCluster).withStepId(stepId); waitForStepExecution(stepRequest); - sparkResponse.setValue(stepId); + sparkResponseReader.setValue(stepId); } @SneakyThrows diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java new file mode 100644 index 0000000000..61729b5c12 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java @@ -0,0 +1,17 @@ +package org.opensearch.sql.spark.client; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import org.opensearch.sql.spark.helper.FlintHelper; + +public interface EmrServerlessClient { + + String startJobRun( + String applicationId, + String query, + String datasourceRoleArn, + String executionRoleArn, + String datasourceName, + FlintHelper flintHelper); + + GetJobRunResult 
getJobResult(String applicationId, String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java new file mode 100644 index 0000000000..4bdf6863ff --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CATALOG_JAR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_CATALOG_HIVE_JAR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; + +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.GetJobRunRequest; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobDriver; +import com.amazonaws.services.emrserverless.model.SparkSubmit; +import com.amazonaws.services.emrserverless.model.StartJobRunRequest; +import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Set; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponseReader; + +public class EmrServerlessClientImpl implements EmrServerlessClient { + + private final AWSEMRServerless emrServerless; + private final String sparkApplicationJar; + private SparkResponseReader sparkResponseReader; + private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class); + private static 
final Set terminalStates = Set.of("CANCELLED", "FAILED", "SUCCESS"); + private static final String JOB_NAME = "flint-opensearch-query"; + + public EmrServerlessClientImpl( + AWSEMRServerless emrServerless, SparkResponseReader sparkResponseReader) { + this.emrServerless = emrServerless; + this.sparkApplicationJar = SPARK_SQL_APPLICATION_JAR; + this.sparkResponseReader = sparkResponseReader; + } + + @Override + public String startJobRun( + String applicationId, + String query, + String datasourceRoleArn, + String executionRoleArn, + String datasourceName, + FlintHelper flint) { + StartJobRunRequest request = + new StartJobRunRequest() + .withName(JOB_NAME) + .withApplicationId(applicationId) + .withExecutionRoleArn(executionRoleArn) + .withJobDriver( + new JobDriver() + .withSparkSubmit( + new SparkSubmit() + .withEntryPoint(sparkApplicationJar) + .withEntryPointArguments(query, SPARK_INDEX_NAME) + .withSparkSubmitParameters( + "--class org.opensearch.sql.FlintJob --conf" + + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + + " --conf" + + " spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=" + + datasourceRoleArn + + " --conf spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=" + + datasourceRoleArn + + " --conf" + + " spark.hadoop.aws.catalog.credentials.provider.factory.class=" + + "com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory" + + " --conf spark.hive.metastore.glue.role.arn=" + + datasourceRoleArn + + " --conf spark.jars=" + + GLUE_CATALOG_HIVE_JAR + + "," + + FLINT_CATALOG_JAR + + " --conf spark.jars.packages=" + + "org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT" + + " --conf spark.jars.repositories=" + + "https://aws.oss.sonatype.org/content/repositories/snapshots" + + " --conf" + + " spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" + + " --conf" + + " 
spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" + + " --conf spark.datasource.flint.host=" + + flint.getFlintHost() + + " --conf spark.datasource.flint.port=" + + flint.getFlintPort() + + " --conf spark.datasource.flint.scheme=" + + flint.getFlintScheme() + + " --conf spark.datasource.flint.auth=" + + flint.getFlintAuth() + + " --conf spark.datasource.flint.region=" + + flint.getFlintRegion() + + " --conf" + + " spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + + " --conf" + + " spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions" + + " --conf spark.hadoop.hive.metastore.client.factory.class=" + + "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory" + + " --conf spark.sql.catalog." + + datasourceName + + "=" + + "org.opensearch.sql.FlintDelegateCatalog"))); + StartJobRunResult startJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) () -> emrServerless.startJobRun(request)); + logger.info("Job Run ID: " + startJobRunResult.getJobRunId()); + sparkResponseReader.setValue(startJobRunResult.getJobRunId()); + return startJobRunResult.getJobRunId(); + } + + @Override + public GetJobRunResult getJobResult(String applicationId, String jobId) { + GetJobRunRequest request = + new GetJobRunRequest().withApplicationId(applicationId).withJobRunId(jobId); + GetJobRunResult getJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) () -> emrServerless.getJobRun(request)); + logger.info("Job Run state: " + getJobRunResult.getJobRun().getState()); + return getJobRunResult; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java new file mode 100644 index 0000000000..3879f7c566 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java @@ -0,0 +1,21 @@ +/* + 
* Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.google.gson.Gson; +import lombok.Data; + +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class SparkExecutionEngineConfig { + private String applicationId; + private String region; + + public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) { + return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 65d5a01ba2..fc41c9f1a3 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -7,11 +7,17 @@ public class SparkConstants { public static final String EMR = "emr"; + public static final String EMRS = "emr-serverless"; public static final String STEP_ID_FIELD = "stepId.keyword"; - public static final String SPARK_SQL_APPLICATION_JAR = "s3://spark-datasource/sql-job.jar"; + public static final String SPARK_SQL_APPLICATION_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job.jar"; public static final String SPARK_INDEX_NAME = ".query_execution_result"; public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; + public static final String GLUE_CATALOG_HIVE_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar"; + public static final String FLINT_CATALOG_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar"; public static final String FLINT_DEFAULT_HOST = "localhost"; public static final String FLINT_DEFAULT_PORT = "9200"; public static final String 
FLINT_DEFAULT_SCHEME = "http"; diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java new file mode 100644 index 0000000000..21414e6f26 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INTEGRATION_JAR; + +import java.net.URI; +import java.net.URISyntaxException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import lombok.AllArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.client.EmrServerlessClient; +import org.opensearch.sql.spark.helper.FlintHelper; + +@AllArgsConstructor +public class SparkQueryDispatcher { + + private EmrServerlessClient emrServerlessClient; + + private DataSourceService dataSourceService; + + public String dispatch(String applicationId, String query) { + String datasourceName = getDataSourceName(query); + DataSourceMetadata dataSourceMetadata = + dataSourceService.getRawDataSourceMetadata(datasourceName); + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + return emrServerlessClient.startJobRun( + applicationId, + query, + getDataSourceRoleARN(dataSourceMetadata), + "arn:aws:iam::270824043731:role/emr-job-execution-role", + datasourceName, + getFlintHelper(dataSourceMetadata)); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }); + } + + public JSONObject getQueryResponse(String applicationId, String jobId) { + return AccessController.doPrivileged( + (PrivilegedAction) () -> 
emrServerlessClient.getJobResult(applicationId, jobId)); + } + + private String getDataSourceName(String query) { + return "my_glue"; + } + + private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) { + return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + } + + private FlintHelper getFlintHelper(DataSourceMetadata dataSourceMetadata) + throws URISyntaxException { + String opensearchuri = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.uri"); + URI uri = new URI(opensearchuri); + String auth = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.auth"); + String region = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.region"); + return new FlintHelper( + FLINT_INTEGRATION_JAR, + uri.getHost(), + String.valueOf(uri.getPort()), + uri.getScheme(), + auth, + region); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java index 823ad2da29..77783c436f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -15,7 +15,6 @@ import org.json.JSONObject; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprByteValue; -import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprFloatValue; import org.opensearch.sql.data.model.ExprIntegerValue; @@ -81,7 +80,7 @@ private static LinkedHashMap extractRow( } else if (type == ExprCoreType.FLOAT) { linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); } else if (type == ExprCoreType.DATE) { - linkedHashMap.put(column.getName(), new 
ExprDateValue(row.getString(column.getName()))); + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); } else if (type == ExprCoreType.TIMESTAMP) { linkedHashMap.put( column.getName(), new ExprTimestampValue(row.getString(column.getName()))); diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java new file mode 100644 index 0000000000..a86ed602b5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs; + +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateJobResponse; + +public interface JobExecutorService { + + CreateJobResponse createJob(CreateJobRequest createJobRequest); + + ExecutionEngine.QueryResponse getJobResults(String jobId); + + String getJob(String jobId); + + String cancelJob(String jobIds); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java new file mode 100644 index 0000000000..084306da96 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import lombok.AllArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import 
org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.jobs.model.JobMetadata; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateJobResponse; + +@AllArgsConstructor +public class JobExecutorServiceImpl implements JobExecutorService { + private JobMetadataStorageService jobMetadataStorageService; + private SparkQueryDispatcher sparkQueryDispatcher; + private Settings settings; + + @Override + public CreateJobResponse createJob(CreateJobRequest createJobRequest) { + String sparkExecutionEngineConfigString = + settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + AccessController.doPrivileged( + (PrivilegedAction) + () -> + SparkExecutionEngineConfig.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigString)); + String jobId = + sparkQueryDispatcher.dispatch( + sparkExecutionEngineConfig.getApplicationId(), createJobRequest.getQuery()); + jobMetadataStorageService.storeJobMetadata( + new JobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); + return new CreateJobResponse(jobId); + } + + @Override + public ExecutionEngine.QueryResponse getJobResults(String jobId) { + Optional jobMetadata = jobMetadataStorageService.getJobMetadata(jobId); + if (jobMetadata.isPresent()) { + JSONObject jsonObject = + sparkQueryDispatcher.getQueryResponse( + jobMetadata.get().getApplicationId(), jobMetadata.get().getJobId()); + if (!jsonObject.keySet().contains("status")) { + DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = + new DefaultSparkSqlFunctionResponseHandle(jsonObject); + List result = new ArrayList<>(); 
+ while (sparkSqlFunctionResponseHandle.hasNext()) { + result.add(sparkSqlFunctionResponseHandle.next()); + } + return new ExecutionEngine.QueryResponse( + sparkSqlFunctionResponseHandle.schema(), result, Cursor.None); + } else { + ExecutionEngine.QueryResponse queryResponse = + new ExecutionEngine.QueryResponse(null, null, Cursor.None); + queryResponse.setStatus(jsonObject.getString("status")); + return queryResponse; + } + } + return null; + } + + @Override + public String getJob(String jobId) { + return null; + } + + @Override + public String cancelJob(String jobIds) { + return null; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java new file mode 100644 index 0000000000..52873d4c25 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java @@ -0,0 +1,11 @@ +package org.opensearch.sql.spark.jobs; + +import java.util.Optional; +import org.opensearch.sql.spark.jobs.model.JobMetadata; + +public interface JobMetadataStorageService { + + void storeJobMetadata(JobMetadata jobMetadata); + + Optional getJobMetadata(String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java new file mode 100644 index 0000000000..2e42eca245 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java @@ -0,0 +1,162 @@ +package org.opensearch.sql.spark.jobs; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.apache.commons.io.IOUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import 
org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.spark.jobs.model.JobMetadata; + +public class OpensearchJobMetadataStorageService implements JobMetadataStorageService { + + public static final String JOB_METADATA_INDEX = ".ql-job-metadata"; + private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME = + "job-metadata-index-mapping.yml"; + private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME = + "job-metadata-index-settings.yml"; + private static final Logger LOG = LogManager.getLogger(); + private final Client client; + private final ClusterService clusterService; + + /** + * This class implements JobMetadataStorageService interface using OpenSearch as underlying + * storage. + * + * @param client opensearch NodeClient. + * @param clusterService ClusterService. 
+ */ + public OpensearchJobMetadataStorageService(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public void storeJobMetadata(JobMetadata jobMetadata) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createDataSourcesIndex(); + } + IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX); + indexRequest.id(jobMetadata.getJobId()); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + ActionFuture indexResponseActionFuture; + IndexResponse indexResponse; + try (ThreadContext.StoredContext storedContext = + client.threadPool().getThreadContext().stashContext()) { + indexRequest.source(JobMetadata.convertToXContent(jobMetadata)); + indexResponseActionFuture = client.index(indexRequest); + indexResponse = indexResponseActionFuture.actionGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("JobMetadata : {} successfully created", jobMetadata.getJobId()); + } else { + throw new RuntimeException( + "Saving dataSource metadata information failed with result : " + + indexResponse.getResult().getLowercase()); + } + } + + @Override + public Optional getJobMetadata(String jobId) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createDataSourcesIndex(); + return Optional.empty(); + } + return searchInDataSourcesIndex(QueryBuilders.termQuery("jobId", jobId)).stream().findFirst(); + } + + private void createDataSourcesIndex() { + try { + InputStream mappingFileStream = + OpensearchJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME); + InputStream settingsFileStream = + OpensearchJobMetadataStorageService.class + .getClassLoader() + 
.getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX); + createIndexRequest + .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML) + .settings( + IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); + ActionFuture createIndexResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); + } + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + if (createIndexResponse.isAcknowledged()) { + LOG.info("Index: {} creation Acknowledged", JOB_METADATA_INDEX); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + throw new RuntimeException( + "Internal server error while creating" + + JOB_METADATA_INDEX + + " index:: " + + e.getMessage()); + } + } + + private List searchInDataSourcesIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(JOB_METADATA_INDEX); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchSourceBuilder.size(1); + searchRequest.source(searchSourceBuilder); + // https://github.com/opensearch-project/sql/issues/1801. 
+ searchRequest.preference("_primary_first"); + ActionFuture searchResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + searchResponseActionFuture = client.search(searchRequest); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching dataSource metadata information failed with status : " + + searchResponse.status()); + } else { + List list = new ArrayList<>(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + String sourceAsString = searchHit.getSourceAsString(); + JobMetadata jobMetadata; + try { + jobMetadata = JobMetadata.toJobMetadata(sourceAsString); + } catch (IOException e) { + throw new RuntimeException(e); + } + list.add(jobMetadata); + } + return list; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java new file mode 100644 index 0000000000..4af54730a0 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java @@ -0,0 +1,87 @@ +package org.opensearch.sql.spark.jobs.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import lombok.AllArgsConstructor; +import lombok.Data; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +/** This class models all the metadata required for a job. */ +@Data +@AllArgsConstructor +public class JobMetadata { + private String jobId; + private String applicationId; + + /** + * Converts JobMetadata to XContentBuilder. 
+ * + * @param metadata metadata. + * @return XContentBuilder {@link XContentBuilder} + * @throws Exception Exception. + */ + public static XContentBuilder convertToXContent(JobMetadata metadata) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("jobId", metadata.getJobId()); + builder.field("applicationId", metadata.getApplicationId()); + builder.endObject(); + return builder; + } + + /** + * Converts json string to JobMetadata. + * + * @param json json string. + * @return jobmetadata {@link JobMetadata} + * @throws java.io.IOException IOException. + */ + public static JobMetadata toJobMetadata(String json) throws IOException { + try (XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + json)) { + return toJobMetadata(parser); + } + } + + /** + * Convert xcontent parser to JobMetadata. + * + * @param parser parser. + * @return JobMetadata {@link JobMetadata} + * @throws IOException IOException.
+ */ + public static JobMetadata toJobMetadata(XContentParser parser) throws IOException { + String jobId = null; + String applicationId = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case "jobId": + jobId = parser.textOrNull(); + break; + case "applicationId": + applicationId = parser.textOrNull(); + break; + default: + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + } + if (jobId == null || applicationId == null) { + throw new IllegalArgumentException("jobId and applicationId are required fields."); + } + return new JobMetadata(jobId, applicationId); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponseReader.java similarity index 88% rename from spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java rename to spark/src/main/java/org/opensearch/sql/spark/response/SparkResponseReader.java index 3edb541384..1796c12b42 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponseReader.java @@ -26,7 +26,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; @Data -public class SparkResponse { +public class SparkResponseReader { private final Client client; private String value; private final String field; @@ -39,7 +39,7 @@ public class SparkResponse { * @param value Identifier field value * @param field Identifier field name */ - public SparkResponse(Client client, String value, String field) { + public SparkResponseReader(Client client, String value, String field) { this.client = client; this.value = value; this.field = field; @@ -69,12 +69,15 @@ private JSONObject searchInSparkIndex(QueryBuilder query) { + " 
index failed with status : " + searchResponse.status()); } else { - JSONObject data = new JSONObject(); - for (SearchHit searchHit : searchResponse.getHits().getHits()) { - data.put("data", searchHit.getSourceAsMap()); - deleteInSparkIndex(searchHit.getId()); + if (searchResponse.getHits().getTotalHits().value == 0) { + return null; + } else { + JSONObject data = new JSONObject(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + data.put("data", searchHit.getSourceAsMap()); + } + return data; } - return data; } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java index 669cbb6aca..f386dfb7b6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java @@ -138,7 +138,7 @@ public void onResponse(CreateJobActionResponse createJobActionResponse) { new BytesRestResponse( RestStatus.CREATED, "application/json; charset=UTF-8", - submitJobRequest.getQuery())); + createJobActionResponse.getResult())); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java new file mode 100644 index 0000000000..9f4990de34 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class CreateJobResponse { + private String jobId; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java index 467bacbaea..bf1c4cde05 
100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java @@ -26,7 +26,7 @@ import org.opensearch.sql.spark.client.EmrClientImpl; import org.opensearch.sql.spark.client.SparkClient; import org.opensearch.sql.spark.helper.FlintHelper; -import org.opensearch.sql.spark.response.SparkResponse; +import org.opensearch.sql.spark.response.SparkResponseReader; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.sql.storage.StorageEngine; @@ -94,7 +94,7 @@ StorageEngine getStorageEngine(Map requiredConfig) { requiredConfig.get(FLINT_SCHEME), requiredConfig.get(FLINT_AUTH), requiredConfig.get(FLINT_REGION)), - new SparkResponse(client, null, STEP_ID_FIELD), + new SparkResponseReader(client, null, STEP_ID_FIELD), requiredConfig.get(SPARK_SQL_APPLICATION)); }); } else { diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java index 53ae9fad90..1740ff92c1 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java @@ -12,6 +12,11 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.spark.jobs.JobExecutorService; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateJobResponse; import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; import 
org.opensearch.tasks.Task; @@ -20,20 +25,37 @@ public class TransportCreateJobRequestAction extends HandledTransportAction { + private JobExecutorService jobExecutorService; + public static final String NAME = "cluster:admin/opensearch/ql/jobs/create"; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CreateJobActionResponse::new); @Inject public TransportCreateJobRequestAction( - TransportService transportService, ActionFilters actionFilters) { + TransportService transportService, + ActionFilters actionFilters, + JobExecutorServiceImpl jobManagementService) { super(NAME, transportService, actionFilters, CreateJobActionRequest::new); + this.jobExecutorService = jobManagementService; } @Override protected void doExecute( Task task, CreateJobActionRequest request, ActionListener listener) { - String responseContent = "submitted_job"; - listener.onResponse(new CreateJobActionResponse(responseContent)); + try { + CreateJobRequest createJobRequest = request.getCreateJobRequest(); + CreateJobResponse createJobResponse = jobExecutorService.createJob(createJobRequest); + String responseContent = + new JsonResponseFormatter(JsonResponseFormatter.Style.PRETTY) { + @Override + protected Object buildJsonObject(CreateJobResponse response) { + return response; + } + }.format(createJobResponse); + listener.onResponse(new CreateJobActionResponse(responseContent)); + } catch (Exception e) { + listener.onFailure(e); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java index 6aba1b48b6..0d5a506d07 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java @@ -7,11 +7,19 @@ package org.opensearch.sql.spark.transport; +import org.json.JSONObject; import 
org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.protocol.response.QueryResult; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.protocol.response.format.ResponseFormatter; +import org.opensearch.sql.protocol.response.format.SimpleJsonResponseFormatter; +import org.opensearch.sql.spark.jobs.JobExecutorService; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; import org.opensearch.tasks.Task; @@ -21,14 +29,19 @@ public class TransportGetQueryResultRequestAction extends HandledTransportAction< GetJobQueryResultActionRequest, GetJobQueryResultActionResponse> { + private JobExecutorService jobExecutorService; + public static final String NAME = "cluster:admin/opensearch/ql/jobs/result"; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, GetJobQueryResultActionResponse::new); @Inject public TransportGetQueryResultRequestAction( - TransportService transportService, ActionFilters actionFilters) { + TransportService transportService, + ActionFilters actionFilters, + JobExecutorServiceImpl jobManagementService) { super(NAME, transportService, actionFilters, GetJobQueryResultActionRequest::new); + this.jobExecutorService = jobManagementService; } @Override @@ -36,7 +49,26 @@ protected void doExecute( Task task, GetJobQueryResultActionRequest request, ActionListener listener) { - String responseContent = "job result"; - listener.onResponse(new GetJobQueryResultActionResponse(responseContent)); + try { + String jobId = request.getJobId(); + ExecutionEngine.QueryResponse 
queryResponse = jobExecutorService.getJobResults(jobId); + if (!queryResponse.getStatus().equals("COMPLETED")) { + JSONObject jsonObject = new JSONObject(); + jsonObject.put("status", queryResponse.getStatus()); + listener.onResponse(new GetJobQueryResultActionResponse(jsonObject.toString())); + } else { + ResponseFormatter formatter = + new SimpleJsonResponseFormatter(JsonResponseFormatter.Style.PRETTY); + String responseContent = + formatter.format( + new QueryResult( + queryResponse.getSchema(), + queryResponse.getResults(), + queryResponse.getCursor())); + listener.onResponse(new GetJobQueryResultActionResponse(responseContent)); + } + } catch (Exception e) { + listener.onFailure(e); + } } } diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml new file mode 100644 index 0000000000..ec2c83a4df --- /dev/null +++ b/spark/src/main/resources/job-metadata-index-mapping.yml @@ -0,0 +1,20 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the .ql-job-metadata index +# "dynamic" is set to "false": documents may carry other fields (kept in _source), but they are not mapped or indexed.
+dynamic: false +properties: + jobId: + type: text + fields: + keyword: + type: keyword + applicationId: + type: text + fields: + keyword: + type: keyword \ No newline at end of file diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/job-metadata-index-settings.yml new file mode 100644 index 0000000000..be93f4645c --- /dev/null +++ b/spark/src/main/resources/job-metadata-index-settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .ql-job-metadata index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" \ No newline at end of file diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java index 93dc0d6bc8..ea318f91b7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java @@ -24,14 +24,14 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.helper.FlintHelper; -import org.opensearch.sql.spark.response.SparkResponse; +import org.opensearch.sql.spark.response.SparkResponseReader; @ExtendWith(MockitoExtension.class) public class EmrClientImplTest { @Mock private AmazonElasticMapReduce emr; @Mock private FlintHelper flint; - @Mock private SparkResponse sparkResponse; + @Mock private SparkResponseReader sparkResponseReader; @Test @SneakyThrows @@ -48,7 +48,7 @@ void testRunEmrApplication() { when(emr.describeStep(any())).thenReturn(describeStepResult); EmrClientImpl emrClientImpl = - new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponseReader, null); emrClientImpl.runEmrApplication(QUERY); } @@ -67,7 +67,7 @@ void testRunEmrApplicationFailed() { 
when(emr.describeStep(any())).thenReturn(describeStepResult); EmrClientImpl emrClientImpl = - new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponseReader, null); RuntimeException exception = Assertions.assertThrows( RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); @@ -89,7 +89,7 @@ void testRunEmrApplicationCancelled() { when(emr.describeStep(any())).thenReturn(describeStepResult); EmrClientImpl emrClientImpl = - new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponseReader, null); RuntimeException exception = Assertions.assertThrows( RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); @@ -121,7 +121,7 @@ void testRunEmrApplicationRunnning() { .thenReturn(completedDescribeStepResult); EmrClientImpl emrClientImpl = - new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponseReader, null); emrClientImpl.runEmrApplication(QUERY); } @@ -148,11 +148,11 @@ void testSql() { when(emr.describeStep(any())) .thenReturn(runningDescribeStepResult) .thenReturn(completedDescribeStepResult); - when(sparkResponse.getResultFromOpensearchIndex()) + when(sparkResponseReader.getResultFromOpensearchIndex()) .thenReturn(new JSONObject(getJson("select_query_response.json"))); EmrClientImpl emrClientImpl = - new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponseReader, null); emrClientImpl.sql(QUERY); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java new file mode 100644 index 0000000000..49c6e4256b --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -0,0 +1,69 @@ +/* 
Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_DATASOURCE_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_DATASOURCE_NAME; + +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponseReader; + +@ExtendWith(MockitoExtension.class) +public class EmrServerlessClientImplTest { + @Mock private AWSEMRServerless emrServerless; + @Mock private FlintHelper flint; + @Mock private SparkResponseReader sparkResponseReader; + + @Test + void testStartJobRun() { + StartJobRunResult response = new StartJobRunResult(); + when(emrServerless.startJobRun(any())).thenReturn(response); + + EmrServerlessClientImpl emrServerlessClient = + new EmrServerlessClientImpl(emrServerless, sparkResponseReader); + emrServerlessClient.startJobRun( + EMRS_APPLICATION_ID, + QUERY, + EMRS_DATASOURCE_ROLE, + EMRS_EXECUTION_ROLE, + TEST_DATASOURCE_NAME, + flint); + } + + @Test + void testGetJobRunState() { + JobRun jobRun = new JobRun(); + jobRun.setState("Running"); + 
GetJobRunResult response = new GetJobRunResult(); + response.setJobRun(jobRun); + when(emrServerless.getJobRun(any())).thenReturn(response); + EmrServerlessClientImpl emrServerlessClient = + new EmrServerlessClientImpl(emrServerless, sparkResponseReader); + emrServerlessClient.getJobRunState(EMRS_APPLICATION_ID, "123"); + } + + @Test + void testCancelJobRun() { + when(emrServerless.cancelJobRun(any())).thenReturn(new CancelJobRunResult()); + + EmrServerlessClientImpl emrServerlessClient = + new EmrServerlessClientImpl(emrServerless, sparkResponseReader); + emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, "123"); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index 2b1020568a..0da313255d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -7,5 +7,10 @@ public class TestConstants { public static final String QUERY = "select 1"; + public static final String TEST_DATASOURCE_NAME = "test_datasource_name"; public static final String EMR_CLUSTER_ID = "j-123456789"; + public static final String EMRS_APPLICATION_ID = "xxxxx"; + public static final String EMRS_EXECUTION_ROLE = "execution_role"; + public static final String EMRS_DATASOURCE_ROLE = "datasource_role"; + public static final String EMRS_JOB_NAME = "job_name"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseReaderTest.java similarity index 77% rename from spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java rename to spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseReaderTest.java index 211561ac72..ef7a0b078c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java +++ 
b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseReaderTest.java @@ -31,7 +31,7 @@ import org.opensearch.search.SearchHits; @ExtendWith(MockitoExtension.class) -public class SparkResponseTest { +public class SparkResponseReaderTest { @Mock private Client client; @Mock private SearchResponse searchResponse; @Mock private DeleteResponse deleteResponse; @@ -54,8 +54,9 @@ public void testGetResultFromOpensearchIndex() { when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.DELETED); - SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); - assertFalse(sparkResponse.getResultFromOpensearchIndex().isEmpty()); + SparkResponseReader sparkResponseReader = + new SparkResponseReader(client, EMR_CLUSTER_ID, "stepId"); + assertFalse(sparkResponseReader.getResultFromOpensearchIndex().isEmpty()); } @Test @@ -64,9 +65,11 @@ public void testInvalidSearchResponse() { when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT); - SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + SparkResponseReader sparkResponseReader = + new SparkResponseReader(client, EMR_CLUSTER_ID, "stepId"); RuntimeException exception = - assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex()); + assertThrows( + RuntimeException.class, () -> sparkResponseReader.getResultFromOpensearchIndex()); Assertions.assertEquals( "Fetching result from " + SPARK_INDEX_NAME @@ -78,15 +81,17 @@ public void testInvalidSearchResponse() { @Test public void testSearchFailure() { when(client.search(any())).thenThrow(RuntimeException.class); - SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); - assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex()); + SparkResponseReader sparkResponseReader 
= + new SparkResponseReader(client, EMR_CLUSTER_ID, "stepId"); + assertThrows(RuntimeException.class, () -> sparkResponseReader.getResultFromOpensearchIndex()); } @Test public void testDeleteFailure() { when(client.delete(any())).thenThrow(RuntimeException.class); - SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); - assertThrows(RuntimeException.class, () -> sparkResponse.deleteInSparkIndex("id")); + SparkResponseReader sparkResponseReader = + new SparkResponseReader(client, EMR_CLUSTER_ID, "stepId"); + assertThrows(RuntimeException.class, () -> sparkResponseReader.deleteInSparkIndex("id")); } @Test @@ -95,10 +100,11 @@ public void testNotFoundDeleteResponse() { when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); - SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + SparkResponseReader sparkResponseReader = + new SparkResponseReader(client, EMR_CLUSTER_ID, "stepId"); RuntimeException exception = assertThrows( - ResourceNotFoundException.class, () -> sparkResponse.deleteInSparkIndex("123")); + ResourceNotFoundException.class, () -> sparkResponseReader.deleteInSparkIndex("123")); Assertions.assertEquals("Spark result with id 123 doesn't exist", exception.getMessage()); } @@ -108,9 +114,10 @@ public void testInvalidDeleteResponse() { when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); - SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + SparkResponseReader sparkResponseReader = + new SparkResponseReader(client, EMR_CLUSTER_ID, "stepId"); RuntimeException exception = - assertThrows(RuntimeException.class, () -> sparkResponse.deleteInSparkIndex("123")); + assertThrows(RuntimeException.class, () -> sparkResponseReader.deleteInSparkIndex("123")); Assertions.assertEquals( 
"Deleting spark result information failed with : noop", exception.getMessage()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java index 4357899368..99337aab6e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java @@ -19,6 +19,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateJobRequest; import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; @@ -30,6 +31,7 @@ public class TransportCreateJobRequestActionTest { @Mock private TransportService transportService; @Mock private TransportCreateJobRequestAction action; + @Mock private JobExecutorServiceImpl jobManagementService; @Mock private Task task; @Mock private ActionListener actionListener; @@ -38,7 +40,8 @@ public class TransportCreateJobRequestActionTest { @BeforeEach public void setUp() { action = - new TransportCreateJobRequestAction(transportService, new ActionFilters(new HashSet<>())); + new TransportCreateJobRequestAction( + transportService, new ActionFilters(new HashSet<>()), jobManagementService); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java index f22adead49..fd6dfef11b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java +++ 
b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java @@ -19,6 +19,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; import org.opensearch.tasks.Task; @@ -31,6 +32,7 @@ public class TransportGetQueryResultRequestActionTest { @Mock private TransportGetQueryResultRequestAction action; @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private JobExecutorServiceImpl jobManagementService; @Captor private ArgumentCaptor createJobActionResponseArgumentCaptor; @@ -39,7 +41,7 @@ public class TransportGetQueryResultRequestActionTest { public void setUp() { action = new TransportGetQueryResultRequestAction( - transportService, new ActionFilters(new HashSet<>())); + transportService, new ActionFilters(new HashSet<>()), jobManagementService); } @Test