From 63a0a75920efe88cd5c8e722788ac8770addb648 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 8 Sep 2023 17:03:37 -0700 Subject: [PATCH] Create Job API Signed-off-by: Vamsi Manohar --- common/build.gradle | 4 +- .../sql/common/setting/Settings.java | 2 +- .../sql/datasource/DataSourceService.java | 9 + .../sql/analysis/AnalyzerTestBase.java | 5 + .../service/DataSourceServiceImpl.java | 41 ++- .../service/DataSourceServiceImplTest.java | 30 ++- docs/user/interfaces/jobinterface.rst | 18 ++ integ-test/build.gradle | 1 + .../setting/OpenSearchSettings.java | 13 + plugin/build.gradle | 4 +- .../org/opensearch/sql/plugin/SQLPlugin.java | 49 +++- .../plugin-metadata/plugin-security.policy | 9 + spark/build.gradle | 10 +- .../sql/spark/client/EmrClientImpl.java | 4 +- .../sql/spark/client/EmrServerlessClient.java | 15 ++ .../spark/client/EmrServerlessClientImpl.java | 68 +++++ .../config/SparkExecutionEngineConfig.java | 22 ++ .../spark/data/constants/SparkConstants.java | 52 +++- .../dispatcher/SparkQueryDispatcher.java | 101 ++++++++ ...DefaultSparkSqlFunctionResponseHandle.java | 3 +- .../sql/spark/jobs/JobExecutorService.java | 30 +++ .../spark/jobs/JobExecutorServiceImpl.java | 76 ++++++ .../spark/jobs/JobMetadataStorageService.java | 11 + .../OpensearchJobMetadataStorageService.java | 161 ++++++++++++ .../jobs/exceptions/JobNotFoundException.java | 15 ++ .../jobs/model/JobExecutionResponse.java | 13 + .../sql/spark/jobs/model/JobMetadata.java | 93 +++++++ .../model/S3GlueSparkSubmitParameters.java | 90 +++++++ .../response/JobExecutionResponseReader.java | 67 +++++ .../sql/spark/response/SparkResponse.java | 8 +- .../spark/rest/RestJobManagementAction.java | 2 +- .../spark/rest/model/CreateJobResponse.java | 15 ++ .../TransportCreateJobRequestAction.java | 28 ++- .../TransportGetQueryResultRequestAction.java | 39 ++- .../resources/job-metadata-index-mapping.yml | 20 ++ .../resources/job-metadata-index-settings.yml | 11 + .../client/EmrServerlessClientImplTest.java | 48 ++++ .../sql/spark/constants/TestConstants.java | 7 + .../dispatcher/SparkQueryDispatcherTest.java | 178 +++++++++++++ .../jobs/JobExecutorServiceImplTest.java | 119 +++++++++ ...ensearchJobMetadataStorageServiceTest.java | 238 ++++++++++++++++++ .../JobExecutionResponseReaderTest.java | 78 ++++++ .../sql/spark/response/SparkResponseTest.java | 4 +- .../TransportCreateJobRequestActionTest.java | 30 ++- ...nsportGetQueryResultRequestActionTest.java | 92 ++++++- 45 files changed, 1873 insertions(+), 60 deletions(-) create mode 100644 docs/user/interfaces/jobinterface.rst create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/exceptions/JobNotFoundException.java create mode 100644 
spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobExecutionResponse.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/jobs/model/S3GlueSparkSubmitParameters.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java create mode 100644 spark/src/main/resources/job-metadata-index-mapping.yml create mode 100644 spark/src/main/resources/job-metadata-index-settings.yml create mode 100644 spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImplTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageServiceTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/response/JobExecutionResponseReaderTest.java diff --git a/common/build.gradle b/common/build.gradle index 5cf219fbae..109cad59cb 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -39,8 +39,8 @@ dependencies { api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' implementation 'com.github.babbel:okhttp-aws-signer:1.0.2' - api group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.1' - api group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.1' + api group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.545' + api group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.545' implementation "com.github.seancfoley:ipaddress:5.4.0" testImplementation group: 'junit', name: 'junit', version: '4.13.2' diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index be780e8d80..8daf0e9bf6 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -35,7 +35,7 @@ public enum Key { METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"), METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"), - + SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"), CLUSTER_NAME("cluster.name"); @Getter private final String keyValue; diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 3d6ddc864e..6dace50f99 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -39,6 +39,15 @@ public interface DataSourceService { */ DataSourceMetadata getDataSourceMetadata(String name); + /** + * Returns the dataSourceMetadata object with the given name. The returned object contains all the + * metadata information without any filtering. + * + * @param name name of the {@link DataSource}. + * @return {@link DataSourceMetadata}. + */ + DataSourceMetadata getRawDataSourceMetadata(String name); + /** * Register {@link DataSource} defined by {@link DataSourceMetadata}.
* diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index f09bc5d380..a16d57673e 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -208,6 +208,11 @@ public DataSourceMetadata getDataSourceMetadata(String name) { return null; } + @Override + public DataSourceMetadata getRawDataSourceMetadata(String name) { + return null; + } + @Override public void createDataSource(DataSourceMetadata metadata) { throw new UnsupportedOperationException("unsupported operation"); diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 2ac480bbf2..d6c1907f84 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -64,29 +64,17 @@ public Set<DataSourceMetadata> getDataSourceMetadata(boolean isDefaultDataSource } @Override - public DataSourceMetadata getDataSourceMetadata(String datasourceName) { - Optional<DataSourceMetadata> dataSourceMetadataOptional = - getDataSourceMetadataFromName(datasourceName); - if (dataSourceMetadataOptional.isEmpty()) { - throw new IllegalArgumentException( - "DataSource with name: " + datasourceName + " doesn't exist."); - } - removeAuthInfo(dataSourceMetadataOptional.get()); - return dataSourceMetadataOptional.get(); + public DataSourceMetadata getDataSourceMetadata(String dataSourceName) { + DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); + removeAuthInfo(dataSourceMetadata); + return dataSourceMetadata; } @Override public DataSource getDataSource(String dataSourceName) { - Optional<DataSourceMetadata> dataSourceMetadataOptional = - getDataSourceMetadataFromName(dataSourceName); - if (dataSourceMetadataOptional.isEmpty()) { - throw new DataSourceNotFoundException( - String.format("DataSource with name %s doesn't exist.", dataSourceName)); - } else { - DataSourceMetadata dataSourceMetadata = dataSourceMetadataOptional.get(); - this.dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - return dataSourceLoaderCache.getOrLoadDataSource(dataSourceMetadata); - } + DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); + this.dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); + return dataSourceLoaderCache.getOrLoadDataSource(dataSourceMetadata); } @Override @@ -146,11 +134,20 @@ private void validateDataSourceMetaData(DataSourceMetadata metadata) { + " Properties are required parameters."); } - private Optional<DataSourceMetadata> getDataSourceMetadataFromName(String dataSourceName) { + @Override + public DataSourceMetadata getRawDataSourceMetadata(String dataSourceName) { if (dataSourceName.equals(DEFAULT_DATASOURCE_NAME)) { - return Optional.of(DataSourceMetadata.defaultOpenSearchDataSourceMetadata()); + return DataSourceMetadata.defaultOpenSearchDataSourceMetadata(); + } else { - return this.dataSourceMetadataStorage.getDataSourceMetadata(dataSourceName); + Optional<DataSourceMetadata> dataSourceMetadataOptional = + this.dataSourceMetadataStorage.getDataSourceMetadata(dataSourceName); + if (dataSourceMetadataOptional.isEmpty()) { + throw new DataSourceNotFoundException( + String.format("DataSource with name %s doesn't exist.", dataSourceName)); + } else { + return
dataSourceMetadataOptional.get(); + } } } diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index 56d3586c6e..eb28495541 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -359,11 +359,11 @@ void testRemovalOfAuthorizationInfo() { @Test void testGetDataSourceMetadataForNonExistingDataSource() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")).thenReturn(Optional.empty()); - IllegalArgumentException exception = + DataSourceNotFoundException exception = assertThrows( - IllegalArgumentException.class, + DataSourceNotFoundException.class, () -> dataSourceService.getDataSourceMetadata("testDS")); - assertEquals("DataSource with name: testDS doesn't exist.", exception.getMessage()); + assertEquals("DataSource with name testDS doesn't exist.", exception.getMessage()); } @Test @@ -385,4 +385,28 @@ void testGetDataSourceMetadataForSpecificDataSourceName() { assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.password")); verify(dataSourceMetadataStorage, times(1)).getDataSourceMetadata("testDS"); } + + @Test + void testGetRawDataSourceMetadata() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("prometheus.uri", "https://localhost:9090"); + properties.put("prometheus.auth.type", "basicauth"); + properties.put("prometheus.auth.username", "username"); + properties.put("prometheus.auth.password", "password"); + DataSourceMetadata dataSourceMetadata = + new DataSourceMetadata( + "testDS", + DataSourceType.PROMETHEUS, + Collections.singletonList("prometheus_access"), + properties); + when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) + .thenReturn(Optional.of(dataSourceMetadata)); + + DataSourceMetadata dataSourceMetadata1 = dataSourceService.getRawDataSourceMetadata("testDS"); + assertEquals("testDS", dataSourceMetadata1.getName()); + assertEquals(DataSourceType.PROMETHEUS, dataSourceMetadata1.getConnector()); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.password")); + } } diff --git a/docs/user/interfaces/jobinterface.rst b/docs/user/interfaces/jobinterface.rst new file mode 100644 index 0000000000..9c0f66f01f --- /dev/null +++ b/docs/user/interfaces/jobinterface.rst @@ -0,0 +1,18 @@ +.. highlight:: sh + +======== +Endpoint +======== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 1 + + +Introduction +============ + +To support the s3Glue and CloudWatch data sources, we have introduced a new execution engine on top of Spark. +All queries to be executed on the Spark execution engine are exposed via the Job APIs.
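+
+Example
+=======
+
+A sketch of submitting a query through the Job API. The route is served by
+``RestJobManagementAction``; the exact path and the ``jobId`` value shown in
+the response below are illustrative assumptions, not a fixed contract::
+
+    sh$ curl -sS -H 'Content-Type: application/json' -X POST 'localhost:9200/_plugins/_query/_jobs' -d '{"query" : "select 1"}'
+    {"jobId": "00fd796cmllkk00q"}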
diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 0404900450..dc92f9ebb3 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -162,6 +162,7 @@ configurations.all { resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk7:1.5.31" resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36" + resolutionStrategy.force "com.amazonaws:aws-java-sdk-core:1.12.545" } configurations { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 48ceacaf10..76bda07607 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -129,6 +129,12 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting<String> SPARK_EXECUTION_ENGINE_CONFIG = + Setting.simpleString( + Key.SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue(), + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -193,6 +199,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.DATASOURCES_URI_HOSTS_DENY_LIST, DATASOURCE_URI_HOSTS_DENY_LIST, new Updater(Key.DATASOURCES_URI_HOSTS_DENY_LIST)); + register( + settingBuilder, + clusterSettings, + Key.SPARK_EXECUTION_ENGINE_CONFIG, + SPARK_EXECUTION_ENGINE_CONFIG, + new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -257,6 +269,7 @@ public static List<Setting<?>> pluginSettings() { .add(METRICS_ROLLING_WINDOW_SETTING) .add(METRICS_ROLLING_INTERVAL_SETTING) .add(DATASOURCE_URI_HOSTS_DENY_LIST) + .add(SPARK_EXECUTION_ENGINE_CONFIG) .build(); } diff --git a/plugin/build.gradle b/plugin/build.gradle index 9e2011059d..53d2e21f10 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -152,8 +152,8 @@ dependencies { testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.12.13' testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' - testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.4.0' - testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.4.0' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.5.0' + testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.5.0' testImplementation 'org.junit.jupiter:junit-jupiter:5.6.2' } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 80e1a6b1a3..98a30a4b2c 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -7,8 +7,13 @@ import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; import
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -83,6 +88,15 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; +import org.opensearch.sql.spark.client.EmrServerlessClient; +import org.opensearch.sql.spark.client.EmrServerlessClientImpl; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.jobs.JobExecutorService; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; +import org.opensearch.sql.spark.jobs.JobMetadataStorageService; +import org.opensearch.sql.spark.jobs.OpensearchJobMetadataStorageService; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestJobManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCreateJobRequestAction; @@ -110,6 +124,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { private NodeClient client; private DataSourceServiceImpl dataSourceService; + private JobExecutorService jobExecutorService; private Injector injector; public String name() { @@ -202,6 +217,7 @@ public Collection<Object> createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); + this.jobExecutorService = createJobExecutorService(); ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); @@ -213,7 +229,7 @@ public Collection<Object> createComponents( }); injector = modules.createInjector(); - return ImmutableList.of(dataSourceService); + return ImmutableList.of(dataSourceService, jobExecutorService); } @Override @@ -270,4 +286,35 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); } + + private JobExecutorService createJobExecutorService() { + JobMetadataStorageService jobMetadataStorageService = + new OpensearchJobMetadataStorageService(client, clusterService); + EmrServerlessClient emrServerlessClient = createEMRServerlessClient(); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, this.dataSourceService, jobExecutionResponseReader); + return new JobExecutorServiceImpl( + jobMetadataStorageService, sparkQueryDispatcher, pluginSettings); + } + + private EmrServerlessClient createEMRServerlessClient() { + String sparkExecutionEngineConfigString = + this.pluginSettings.getSettingValue( + org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + return AccessController.doPrivileged( + (PrivilegedAction<EmrServerlessClient>) + () -> { + SparkExecutionEngineConfig sparkExecutionEngineConfig = + SparkExecutionEngineConfig.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigString); + AWSEMRServerless awsemrServerless = + AWSEMRServerlessClientBuilder.standard() + .withRegion(sparkExecutionEngineConfig.getRegion()) + .withCredentials(new
DefaultAWSCredentialsProviderChain()) + .build(); + return new EmrServerlessClientImpl(awsemrServerless); + }); + } } diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index aec517aa84..fcf70c01f9 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -15,4 +15,13 @@ grant { // ml-commons client permission java.lang.RuntimePermission "setContextClassLoader"; + + // aws credentials + permission java.io.FilePermission "${user.home}${/}.aws${/}*", "read"; + + // Permissions for the AWS EMR Serverless SDK + permission javax.management.MBeanServerPermission "createMBeanServer"; + permission javax.management.MBeanServerPermission "findMBeanServer"; + permission javax.management.MBeanPermission "com.amazonaws.metrics.*", "*"; + permission javax.management.MBeanTrustPermission "register"; }; diff --git a/spark/build.gradle b/spark/build.gradle index b93e3327ce..cdefd507fb 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -15,11 +15,14 @@ repositories { dependencies { api project(':core') + implementation project(':protocol') implementation project(':datasources') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.json', name: 'json', version: '20230227' - implementation group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.1' + api group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.545' + api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' + implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' @@ -56,7 +59,10 @@ jacocoTestCoverageVerification { excludes = [ 'org.opensearch.sql.spark.data.constants.*', 'org.opensearch.sql.spark.rest.*', - 'org.opensearch.sql.spark.transport.model.*' + 'org.opensearch.sql.spark.transport.model.*', + 'org.opensearch.sql.spark.jobs.model.*', + 'org.opensearch.sql.spark.jobs.config.*', + 'org.opensearch.sql.spark.jobs.execution.*' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java index 1a3304994b..4e66cd9a00 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.client; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; @@ -74,7 +74,7 @@ void runEmrApplication(String query) { flint.getFlintIntegrationJar(), sparkApplicationJar, query, - SPARK_INDEX_NAME, + SPARK_RESPONSE_BUFFER_INDEX_NAME, flint.getFlintHost(), flint.getFlintPort(), flint.getFlintScheme(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java new file mode 100644 index 0000000000..d29198b735 --- /dev/null +++
b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClient.java @@ -0,0 +1,15 @@ +package org.opensearch.sql.spark.client; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; + +public interface EmrServerlessClient { + + String startJobRun( + String query, + String jobName, + String applicationId, + String executionRoleArn, + String sparkSubmitParams); + + GetJobRunResult getJobRunResult(String applicationId, String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java new file mode 100644 index 0000000000..07672c9348 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; + +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.GetJobRunRequest; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobDriver; +import com.amazonaws.services.emrserverless.model.SparkSubmit; +import com.amazonaws.services.emrserverless.model.StartJobRunRequest; +import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import java.security.AccessController; +import java.security.PrivilegedAction; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class EmrServerlessClientImpl implements EmrServerlessClient { + + private final AWSEMRServerless emrServerless; + private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class); + + public EmrServerlessClientImpl(AWSEMRServerless emrServerless) { + this.emrServerless = emrServerless; + } + + @Override + public String startJobRun( + String query, + String jobName, + String applicationId, + String executionRoleArn, + String sparkSubmitParams) { + StartJobRunRequest request = + new StartJobRunRequest() + .withName(jobName) + .withApplicationId(applicationId) + .withExecutionRoleArn(executionRoleArn) + .withJobDriver( + new JobDriver() + .withSparkSubmit( + new SparkSubmit() + .withEntryPoint(SPARK_SQL_APPLICATION_JAR) + .withEntryPointArguments(query, SPARK_RESPONSE_BUFFER_INDEX_NAME) + .withSparkSubmitParameters(sparkSubmitParams))); + StartJobRunResult startJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction<StartJobRunResult>) () -> emrServerless.startJobRun(request)); + logger.info("Job Run ID: " + startJobRunResult.getJobRunId()); + return startJobRunResult.getJobRunId(); + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + GetJobRunRequest request = + new GetJobRunRequest().withApplicationId(applicationId).withJobRunId(jobId); + GetJobRunResult getJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction<GetJobRunResult>) () -> emrServerless.getJobRun(request)); + logger.info("Job Run state: " + getJobRunResult.getJobRun().getState()); + return getJobRunResult; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java new file mode 100644
index 0000000000..4f928c4f1f --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.google.gson.Gson; +import lombok.Data; + +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class SparkExecutionEngineConfig { + private String applicationId; + private String region; + private String executionRoleARN; + + public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) { + return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 65d5a01ba2..9cea18a63c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -7,14 +7,62 @@ public class SparkConstants { public static final String EMR = "emr"; + public static final String EMRS = "emr-serverless"; public static final String STEP_ID_FIELD = "stepId.keyword"; - public static final String SPARK_SQL_APPLICATION_JAR = "s3://spark-datasource/sql-job.jar"; - public static final String SPARK_INDEX_NAME = ".query_execution_result"; + public static final String SPARK_SQL_APPLICATION_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job.jar"; + public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; + public static final String GLUE_CATALOG_HIVE_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar"; + public static final String FLINT_CATALOG_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar"; public static final String FLINT_DEFAULT_HOST = "localhost"; public static final String FLINT_DEFAULT_PORT = "9200"; public static final String FLINT_DEFAULT_SCHEME = "http"; public static final String FLINT_DEFAULT_AUTH = "-1"; public static final String FLINT_DEFAULT_REGION = "us-west-2"; + public static final String DEFAULT_CLASS_NAME = "org.opensearch.sql.FlintJob"; + public static final String S3_AWS_CREDENTIALS_PROVIDER_KEY = + "spark.hadoop.fs.s3.customAWSCredentialsProvider"; + public static final String DRIVER_ENV_ASSUME_ROLE_ARN_KEY = + "spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN"; + public static final String EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY = + "spark.emr-serverless.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN"; + public static final String HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY = + "spark.hadoop.aws.catalog.credentials.provider.factory.class"; + public static final String HIVE_METASTORE_GLUE_ARN_KEY = "spark.hive.metastore.glue.role.arn"; + public static final String SPARK_JARS_KEY = "spark.jars"; + public static final String SPARK_JAR_PACKAGES_KEY = "spark.jars.packages"; + public static final String SPARK_JAR_REPOSITORIES_KEY = "spark.jars.repositories"; + public static final String SPARK_DRIVER_ENV_JAVA_HOME_KEY = + "spark.emr-serverless.driverEnv.JAVA_HOME"; + public static final String SPARK_EXECUTOR_ENV_JAVA_HOME_KEY = "spark.executorEnv.JAVA_HOME"; + public 
static final String FLINT_INDEX_STORE_HOST_KEY = "spark.datasource.flint.host"; + public static final String FLINT_INDEX_STORE_PORT_KEY = "spark.datasource.flint.port"; + public static final String FLINT_INDEX_STORE_SCHEME_KEY = "spark.datasource.flint.scheme"; + public static final String FLINT_INDEX_STORE_AUTH_KEY = "spark.datasource.flint.auth"; + public static final String FLINT_INDEX_STORE_AWSREGION_KEY = "spark.datasource.flint.region"; + public static final String FLINT_CREDENTIALS_PROVIDER_KEY = + "spark.datasource.flint.customAWSCredentialsProvider"; + public static final String SPARK_SQL_EXTENSIONS_KEY = "spark.sql.extensions"; + public static final String HIVE_METASTORE_CLASS_KEY = + "spark.hadoop.hive.metastore.client.factory.class"; + public static final String DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE = + "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; + public static final String DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY = + "com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory"; + public static final String SPARK_STANDALONE_PACKAGE = + "org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT"; + public static final String AWS_SNAPSHOT_REPOSITORY = + "https://aws.oss.sonatype.org/content/repositories/snapshots"; + public static final String GLUE_HIVE_CATALOG_FACTORY_CLASS = + "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"; + public static final String FLINT_DELEGATE_CATALOG = "org.opensearch.sql.FlintDelegateCatalog"; + public static final String FLINT_SQL_EXTENSION = + "org.opensearch.flint.spark.FlintSparkExtensions"; + public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = + "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; + public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java new file mode 100644 index 0000000000..26964dd1b8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.DRIVER_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DELEGATE_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.net.URI; +import java.net.URISyntaxException; +import lombok.AllArgsConstructor; +import org.json.JSONObject; +import 
org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.client.EmrServerlessClient; +import org.opensearch.sql.spark.jobs.model.S3GlueSparkSubmitParameters; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +/** This class takes care of understanding the query and dispatching the job query to EMR Serverless. */ +@AllArgsConstructor +public class SparkQueryDispatcher { + + private EmrServerlessClient emrServerlessClient; + + private DataSourceService dataSourceService; + + private JobExecutionResponseReader jobExecutionResponseReader; + + public String dispatch(String applicationId, String query, String executionRoleARN) { + String datasourceName = getDataSourceName(); + try { + return emrServerlessClient.startJobRun( + query, + "flint-opensearch-query", + applicationId, + executionRoleARN, + constructSparkParameters(datasourceName)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration of the %s datasource.", datasourceName)); + } + } + + public JSONObject getQueryResponse(String applicationId, String jobId) { + GetJobRunResult getJobRunResult = emrServerlessClient.getJobRunResult(applicationId, jobId); + JSONObject result = new JSONObject(); + if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { + result = jobExecutionResponseReader.getResultFromOpensearchIndex(jobId); + } + result.put("status", getJobRunResult.getJobRun().getState()); + return result; + } + + // TODO: Analyze given query + // Extract datasourceName + // Apply Authorization. + private String getDataSourceName() { + return "my_glue"; + } + + private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) { + return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + } + + private String constructSparkParameters(String datasourceName) throws URISyntaxException { + DataSourceMetadata dataSourceMetadata = + dataSourceService.getRawDataSourceMetadata(datasourceName); + S3GlueSparkSubmitParameters s3GlueSparkSubmitParameters = new S3GlueSparkSubmitParameters(); + s3GlueSparkSubmitParameters.addParameter( + DRIVER_ENV_ASSUME_ROLE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); + s3GlueSparkSubmitParameters.addParameter( + EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); + s3GlueSparkSubmitParameters.addParameter( + HIVE_METASTORE_GLUE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); + String opensearchuri = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.uri"); + URI uri = new URI(opensearchuri); + String auth = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.auth"); + String region = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.region"); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); + s3GlueSparkSubmitParameters.addParameter( + FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_AUTH_KEY, auth); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_AWSREGION_KEY, region); + s3GlueSparkSubmitParameters.addParameter( + "spark.sql.catalog."
+ datasourceName, FLINT_DELEGATE_CATALOG); + return s3GlueSparkSubmitParameters.toString(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java index 823ad2da29..77783c436f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -15,7 +15,6 @@ import org.json.JSONObject; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprByteValue; -import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprFloatValue; import org.opensearch.sql.data.model.ExprIntegerValue; @@ -81,7 +80,7 @@ private static LinkedHashMap<String, ExprValue> extractRow( } else if (type == ExprCoreType.FLOAT) { linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); } else if (type == ExprCoreType.DATE) { - linkedHashMap.put(column.getName(), new ExprDateValue(row.getString(column.getName()))); + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); } else if (type == ExprCoreType.TIMESTAMP) { linkedHashMap.put( column.getName(), new ExprTimestampValue(row.getString(column.getName()))); diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java new file mode 100644 index 0000000000..d59c900b14 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorService.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs; + +import org.opensearch.sql.spark.jobs.model.JobExecutionResponse; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateJobResponse; + +/** JobExecutorService exposes functionality to create a job and get the results of a job. */ +public interface JobExecutorService { + + /** + * Creates a job based on the request and returns the jobId in the response. + * + * @param createJobRequest createJobRequest. + * @return {@link CreateJobResponse} + */ + CreateJobResponse createJob(CreateJobRequest createJobRequest); + + /** + * Returns the job execution response for a given jobId. + * + * @param jobId jobId.
+ * @return {@link JobExecutionResponse} + */ + JobExecutionResponse getJobResults(String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java new file mode 100644 index 0000000000..614f90c1b9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImpl.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs; + +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import lombok.AllArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.jobs.exceptions.JobNotFoundException; +import org.opensearch.sql.spark.jobs.model.JobExecutionResponse; +import org.opensearch.sql.spark.jobs.model.JobMetadata; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateJobResponse; + +/** JobExecutorService implementation of {@link JobExecutorService}. */ +@AllArgsConstructor +public class JobExecutorServiceImpl implements JobExecutorService { + private JobMetadataStorageService jobMetadataStorageService; + private SparkQueryDispatcher sparkQueryDispatcher; + private Settings settings; + + @Override + public CreateJobResponse createJob(CreateJobRequest createJobRequest) { + String sparkExecutionEngineConfigString = + settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + AccessController.doPrivileged( + (PrivilegedAction) + () -> + SparkExecutionEngineConfig.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigString)); + String jobId = + sparkQueryDispatcher.dispatch( + sparkExecutionEngineConfig.getApplicationId(), + createJobRequest.getQuery(), + sparkExecutionEngineConfig.getExecutionRoleARN()); + jobMetadataStorageService.storeJobMetadata( + new JobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); + return new CreateJobResponse(jobId); + } + + @Override + public JobExecutionResponse getJobResults(String jobId) { + Optional jobMetadata = jobMetadataStorageService.getJobMetadata(jobId); + if (jobMetadata.isPresent()) { + JSONObject jsonObject = + sparkQueryDispatcher.getQueryResponse( + jobMetadata.get().getApplicationId(), jobMetadata.get().getJobId()); + if (JobRunState.SUCCESS.toString().equals(jsonObject.getString("status"))) { + DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = + new DefaultSparkSqlFunctionResponseHandle(jsonObject); + List result = new ArrayList<>(); + while (sparkSqlFunctionResponseHandle.hasNext()) { + result.add(sparkSqlFunctionResponseHandle.next()); + } + return new JobExecutionResponse( + JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result); + } else { + return new JobExecutionResponse(jsonObject.getString("status"), null, null); + } + } + throw new JobNotFoundException(String.format("JobId: %s not found", jobId)); + } +} 
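For reference, the createJob flow above reads the plugins.query.executionengine.spark.config cluster setting as a JSON blob and parses it with the Gson-backed SparkExecutionEngineConfig.toSparkExecutionEngineConfig from this patch. A minimal sketch of that round trip; the demo class name and the applicationId, region, and role ARN values are illustrative placeholders:

    import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;

    public class SparkExecutionEngineConfigDemo {
      public static void main(String[] args) {
        // Placeholder JSON; in the plugin this string comes from the
        // plugins.query.executionengine.spark.config cluster setting.
        String json =
            "{\"applicationId\":\"00fd775baqpu4g0p\","
                + "\"region\":\"eu-west-1\","
                + "\"executionRoleARN\":\"arn:aws:iam::123456789012:role/emr-job-execution-role\"}";
        SparkExecutionEngineConfig config =
            SparkExecutionEngineConfig.toSparkExecutionEngineConfig(json);
        // dispatch() receives the applicationId and executionRoleARN parsed here.
        System.out.println(config.getApplicationId() + " " + config.getRegion());
      }
    }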
diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java new file mode 100644 index 0000000000..52873d4c25 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/JobMetadataStorageService.java @@ -0,0 +1,11 @@ +package org.opensearch.sql.spark.jobs; + +import java.util.Optional; +import org.opensearch.sql.spark.jobs.model.JobMetadata; + +public interface JobMetadataStorageService { + + void storeJobMetadata(JobMetadata jobMetadata); + + Optional<JobMetadata> getJobMetadata(String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java new file mode 100644 index 0000000000..a4214289e3 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageService.java @@ -0,0 +1,161 @@ +package org.opensearch.sql.spark.jobs; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.apache.commons.io.IOUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.spark.jobs.model.JobMetadata; + +public class OpensearchJobMetadataStorageService implements JobMetadataStorageService { + + public static final String JOB_METADATA_INDEX = ".ql-job-metadata"; + private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME = + "job-metadata-index-mapping.yml"; + private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME = + "job-metadata-index-settings.yml"; + private static final Logger LOG = LogManager.getLogger(); + private final Client client; + private final ClusterService clusterService; + + /** + * This class implements the JobMetadataStorageService interface using OpenSearch as the + * underlying storage. + * + * @param client OpenSearch NodeClient. + * @param clusterService ClusterService.
+ */ + public OpensearchJobMetadataStorageService(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public void storeJobMetadata(JobMetadata jobMetadata) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createJobMetadataIndex(); + } + IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX); + indexRequest.id(jobMetadata.getJobId()); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + ActionFuture<IndexResponse> indexResponseActionFuture; + IndexResponse indexResponse; + try (ThreadContext.StoredContext storedContext = + client.threadPool().getThreadContext().stashContext()) { + indexRequest.source(JobMetadata.convertToXContent(jobMetadata)); + indexResponseActionFuture = client.index(indexRequest); + indexResponse = indexResponseActionFuture.actionGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("JobMetadata: {} successfully created", jobMetadata.getJobId()); + } else { + throw new RuntimeException( + "Saving job metadata information failed with result: " + + indexResponse.getResult().getLowercase()); + } + } + + @Override + public Optional<JobMetadata> getJobMetadata(String jobId) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createJobMetadataIndex(); + return Optional.empty(); + } + return searchInJobMetadataIndex(QueryBuilders.termQuery("jobId", jobId)).stream().findFirst(); + } + + private void createJobMetadataIndex() { + try { + InputStream mappingFileStream = + OpensearchJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME); + InputStream settingsFileStream = + OpensearchJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX); + createIndexRequest + .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML) + .settings( + IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); + ActionFuture<CreateIndexResponse> createIndexResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); + } + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + if (createIndexResponse.isAcknowledged()) { + LOG.info("Index: {} creation acknowledged", JOB_METADATA_INDEX); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + throw new RuntimeException( + "Internal server error while creating " + + JOB_METADATA_INDEX + + " index: " + + e.getMessage()); + } + } + + private List<JobMetadata> searchInJobMetadataIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(JOB_METADATA_INDEX); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchSourceBuilder.size(1); + searchRequest.source(searchSourceBuilder); + // https://github.com/opensearch-project/sql/issues/1801.
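+ // Preferring primary shards keeps this read consistent with the immediate-refresh write in storeJobMetadata.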
+ searchRequest.preference("_primary_first"); + ActionFuture<SearchResponse> searchResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + searchResponseActionFuture = client.search(searchRequest); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching job metadata information failed with status: " + searchResponse.status()); + } else { + List<JobMetadata> list = new ArrayList<>(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + String sourceAsString = searchHit.getSourceAsString(); + JobMetadata jobMetadata; + try { + jobMetadata = JobMetadata.toJobMetadata(sourceAsString); + } catch (IOException e) { + throw new RuntimeException(e); + } + list.add(jobMetadata); + } + return list; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/exceptions/JobNotFoundException.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/exceptions/JobNotFoundException.java new file mode 100644 index 0000000000..40ccece071 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/exceptions/JobNotFoundException.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs.exceptions; + +/** JobNotFoundException. */ +public class JobNotFoundException extends RuntimeException { + public JobNotFoundException(String message) { + super(message); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobExecutionResponse.java new file mode 100644 index 0000000000..449c7ee2b4 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobExecutionResponse.java @@ -0,0 +1,13 @@ +package org.opensearch.sql.spark.jobs.model; + +import java.util.List; +import lombok.Data; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; + +@Data +public class JobExecutionResponse { + private final String status; + private final ExecutionEngine.Schema schema; + private final List<ExprValue> results; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java new file mode 100644 index 0000000000..d9f628b476 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/JobMetadata.java @@ -0,0 +1,93 @@ +package org.opensearch.sql.spark.jobs.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import com.google.gson.Gson; +import java.io.IOException; +import lombok.AllArgsConstructor; +import lombok.Data; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +/** This class models all the metadata required for a job. */ +@Data +@AllArgsConstructor +public class JobMetadata { + private String jobId; + private String applicationId; + + @Override + public String toString() { + return new Gson().toJson(this); + } + + /** + * Converts JobMetadata to XContentBuilder. + * + * @param metadata metadata.
+ * @return XContentBuilder {@link XContentBuilder} + * @throws Exception Exception. + */ + public static XContentBuilder convertToXContent(JobMetadata metadata) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("jobId", metadata.getJobId()); + builder.field("applicationId", metadata.getApplicationId()); + builder.endObject(); + return builder; + } + + /** + * Converts a JSON string to JobMetadata. + * + * @param json JSON string. + * @return JobMetadata {@link JobMetadata} + * @throws java.io.IOException IOException. + */ + public static JobMetadata toJobMetadata(String json) throws IOException { + try (XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + json)) { + return toJobMetadata(parser); + } + } + + /** + * Converts an XContent parser to JobMetadata. + * + * @param parser parser. + * @return JobMetadata {@link JobMetadata} + * @throws IOException IOException. + */ + public static JobMetadata toJobMetadata(XContentParser parser) throws IOException { + String jobId = null; + String applicationId = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case "jobId": + jobId = parser.textOrNull(); + break; + case "applicationId": + applicationId = parser.textOrNull(); + break; + default: + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + } + if (jobId == null || applicationId == null) { + throw new IllegalArgumentException("jobId and applicationId are required fields."); + } + return new JobMetadata(jobId, applicationId); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/jobs/model/S3GlueSparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/S3GlueSparkSubmitParameters.java new file mode 100644 index 0000000000..e8a4985557 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/jobs/model/S3GlueSparkSubmitParameters.java @@ -0,0 +1,90 @@ +package org.opensearch.sql.spark.jobs.model; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.AWS_SNAPSHOT_REPOSITORY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CATALOG_JAR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_REGION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SQL_EXTENSION;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_CATALOG_HIVE_JAR;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_HIVE_CATALOG_FACTORY_CLASS;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_CLASS_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.JAVA_HOME_LOCATION;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.S3_AWS_CREDENTIALS_PROVIDER_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_JAVA_HOME_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_JAVA_HOME_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_PACKAGES_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_REPOSITORIES_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_EXTENSIONS_KEY;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_STANDALONE_PACKAGE;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+public class S3GlueSparkSubmitParameters {
+
+  private String className;
+  private Map<String, String> config;
+  public static final String SPACE = " ";
+  public static final String EQUALS = "=";
+
+  public S3GlueSparkSubmitParameters() {
+    this.className = DEFAULT_CLASS_NAME;
+    this.config = new LinkedHashMap<>();
+    this.config.put(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE);
+    this.config.put(
+        HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY,
+        DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY);
+    this.config.put(SPARK_JARS_KEY, GLUE_CATALOG_HIVE_JAR + "," + FLINT_CATALOG_JAR);
+    this.config.put(SPARK_JAR_PACKAGES_KEY, SPARK_STANDALONE_PACKAGE);
+    this.config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY);
+    this.config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION);
+    this.config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION);
+    this.config.put(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST);
+    this.config.put(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT);
+    this.config.put(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME);
+    this.config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH);
+    this.config.put(FLINT_INDEX_STORE_AWSREGION_KEY, FLINT_DEFAULT_REGION);
+    this.config.put(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER);
+    this.config.put(SPARK_SQL_EXTENSIONS_KEY, FLINT_SQL_EXTENSION);
+    this.config.put(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS);
+  }
+
+  public void addParameter(String key, String value) {
+    this.config.put(key, value);
+  }
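+
+  // toString() below renders these parameters as a single spark-submit argument string,
+  // e.g. (illustrative): " --class org.opensearch.sql.FlintJob  --conf key=value  --conf key2=value2 "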
stringBuilder.append(" --class "); + stringBuilder.append(this.className); + stringBuilder.append(SPACE); + for (String key : config.keySet()) { + stringBuilder.append(" --conf "); + stringBuilder.append(key); + stringBuilder.append(EQUALS); + stringBuilder.append(config.get(key)); + stringBuilder.append(SPACE); + } + return stringBuilder.toString(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java new file mode 100644 index 0000000000..8abb7cd11f --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class JobExecutionResponseReader { + private final Client client; + private static final Logger LOG = LogManager.getLogger(); + + /** + * JobExecutionResponseReader for spark query. + * + * @param client Opensearch client + */ + public JobExecutionResponseReader(Client client) { + this.client = client; + } + + public JSONObject getResultFromOpensearchIndex(String jobId) { + return searchInSparkIndex(QueryBuilders.termQuery(STEP_ID_FIELD, jobId)); + } + + private JSONObject searchInSparkIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + ActionFuture searchResponseActionFuture; + try { + searchResponseActionFuture = client.search(searchRequest); + } catch (Exception e) { + throw new RuntimeException(e); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching result from " + + SPARK_RESPONSE_BUFFER_INDEX_NAME + + " index failed with status : " + + searchResponse.status()); + } else { + JSONObject data = new JSONObject(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + data.put("data", searchHit.getSourceAsMap()); + } + return data; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java index 3edb541384..496caba2c9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.response; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static 
+  public JSONObject getResultFromOpensearchIndex(String jobId) {
+    return searchInSparkIndex(QueryBuilders.termQuery(STEP_ID_FIELD, jobId));
+  }
+
+  private JSONObject searchInSparkIndex(QueryBuilder query) {
+    SearchRequest searchRequest = new SearchRequest();
+    searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME);
+    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+    searchSourceBuilder.query(query);
+    searchRequest.source(searchSourceBuilder);
+    ActionFuture<SearchResponse> searchResponseActionFuture;
+    try {
+      searchResponseActionFuture = client.search(searchRequest);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+    SearchResponse searchResponse = searchResponseActionFuture.actionGet();
+    if (searchResponse.status().getStatus() != 200) {
+      throw new RuntimeException(
+          "Fetching result from "
+              + SPARK_RESPONSE_BUFFER_INDEX_NAME
+              + " index failed with status : "
+              + searchResponse.status());
+    } else {
+      JSONObject data = new JSONObject();
+      for (SearchHit searchHit : searchResponse.getHits().getHits()) {
+        data.put("data", searchHit.getSourceAsMap());
+      }
+      return data;
+    }
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java
index 3edb541384..496caba2c9 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java
@@ -5,7 +5,7 @@
 
 package org.opensearch.sql.spark.response;
 
-import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME;
 
 import com.google.common.annotations.VisibleForTesting;
 import lombok.Data;
@@ -51,7 +51,7 @@ public JSONObject getResultFromOpensearchIndex() {
 
   private JSONObject searchInSparkIndex(QueryBuilder query) {
     SearchRequest searchRequest = new SearchRequest();
-    searchRequest.indices(SPARK_INDEX_NAME);
+    searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME);
     SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
     searchSourceBuilder.query(query);
     searchRequest.source(searchSourceBuilder);
@@ -65,7 +65,7 @@ private JSONObject searchInSparkIndex(QueryBuilder query) {
     if (searchResponse.status().getStatus() != 200) {
       throw new RuntimeException(
           "Fetching result from "
-              + SPARK_INDEX_NAME
+              + SPARK_RESPONSE_BUFFER_INDEX_NAME
              + " index failed with status : "
               + searchResponse.status());
     } else {
@@ -80,7 +80,7 @@ private JSONObject searchInSparkIndex(QueryBuilder query) {
 
   @VisibleForTesting
   void deleteInSparkIndex(String id) {
-    DeleteRequest deleteRequest = new DeleteRequest(SPARK_INDEX_NAME);
+    DeleteRequest deleteRequest = new DeleteRequest(SPARK_RESPONSE_BUFFER_INDEX_NAME);
     deleteRequest.id(id);
     ActionFuture<DeleteResponse> deleteResponseActionFuture;
     try {
diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java
index 669cbb6aca..f386dfb7b6 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java
@@ -138,7 +138,7 @@ public void onResponse(CreateJobActionResponse createJobActionResponse) {
               new BytesRestResponse(
                   RestStatus.CREATED,
                   "application/json; charset=UTF-8",
-                  submitJobRequest.getQuery()));
+                  createJobActionResponse.getResult()));
         }
 
         @Override
diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java
new file mode 100644
index 0000000000..9f4990de34
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobResponse.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.rest.model;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+@Data
+@AllArgsConstructor
+public class CreateJobResponse {
+  private String jobId;
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java
index 53ae9fad90..35e212d773 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java
@@ -12,6 +12,11 @@
 import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
+import org.opensearch.sql.spark.jobs.JobExecutorService;
+import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl;
+import org.opensearch.sql.spark.rest.model.CreateJobRequest;
+import org.opensearch.sql.spark.rest.model.CreateJobResponse;
 import org.opensearch.sql.spark.transport.model.CreateJobActionRequest;
 import org.opensearch.sql.spark.transport.model.CreateJobActionResponse;
 import org.opensearch.tasks.Task;
@@ -20,20 +25,37 @@ public class TransportCreateJobRequestAction
     extends HandledTransportAction<CreateJobActionRequest, CreateJobActionResponse> {
 
+  private final JobExecutorService jobExecutorService;
+
   public static final String NAME = "cluster:admin/opensearch/ql/jobs/create";
   public static final ActionType<CreateJobActionResponse> ACTION_TYPE =
       new ActionType<>(NAME, CreateJobActionResponse::new);
 
   @Inject
   public TransportCreateJobRequestAction(
-      TransportService transportService, ActionFilters actionFilters) {
+      TransportService transportService,
+      ActionFilters actionFilters,
+      JobExecutorServiceImpl jobManagementService) {
     super(NAME, transportService, actionFilters, CreateJobActionRequest::new);
+    this.jobExecutorService = jobManagementService;
   }
 
   @Override
   protected void doExecute(
       Task task, CreateJobActionRequest request, ActionListener<CreateJobActionResponse> listener) {
-    String responseContent = "submitted_job";
-    listener.onResponse(new CreateJobActionResponse(responseContent));
+    try {
+      CreateJobRequest createJobRequest = request.getCreateJobRequest();
+      CreateJobResponse createJobResponse = jobExecutorService.createJob(createJobRequest);
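+      // Serialize the response as pretty-printed JSON, e.g. (illustrative): {"jobId": "123"}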
+      String responseContent =
+          new JsonResponseFormatter<CreateJobResponse>(JsonResponseFormatter.Style.PRETTY) {
+            @Override
+            protected Object buildJsonObject(CreateJobResponse response) {
+              return response;
+            }
+          }.format(createJobResponse);
+      listener.onResponse(new CreateJobActionResponse(responseContent));
+    } catch (Exception e) {
+      listener.onFailure(e);
+    }
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java
index 6aba1b48b6..2237cdf489 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java
@@ -7,11 +7,20 @@
 
 package org.opensearch.sql.spark.transport;
 
+import org.json.JSONObject;
 import org.opensearch.action.ActionType;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.executor.pagination.Cursor;
+import org.opensearch.sql.protocol.response.QueryResult;
+import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
+import org.opensearch.sql.protocol.response.format.ResponseFormatter;
+import org.opensearch.sql.protocol.response.format.SimpleJsonResponseFormatter;
+import org.opensearch.sql.spark.jobs.JobExecutorService;
+import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl;
+import org.opensearch.sql.spark.jobs.model.JobExecutionResponse;
 import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest;
 import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse;
 import org.opensearch.tasks.Task;
@@ -21,14 +30,19 @@ public class TransportGetQueryResultRequestAction
     extends HandledTransportAction<
         GetJobQueryResultActionRequest, GetJobQueryResultActionResponse> {
 
+  private final JobExecutorService jobExecutorService;
+
   public static final String NAME = "cluster:admin/opensearch/ql/jobs/result";
   public static final ActionType<GetJobQueryResultActionResponse> ACTION_TYPE =
       new ActionType<>(NAME, GetJobQueryResultActionResponse::new);
 
   @Inject
   public TransportGetQueryResultRequestAction(
-      TransportService transportService, ActionFilters actionFilters) {
+      TransportService transportService,
+      ActionFilters actionFilters,
+      JobExecutorServiceImpl jobManagementService) {
     super(NAME, transportService, actionFilters, GetJobQueryResultActionRequest::new);
+    this.jobExecutorService = jobManagementService;
   }
 
   @Override
@@ -36,7 +50,26 @@ protected void doExecute(
       Task task,
       GetJobQueryResultActionRequest request,
       ActionListener<GetJobQueryResultActionResponse> listener) {
-    String responseContent = "job result";
-    listener.onResponse(new GetJobQueryResultActionResponse(responseContent));
+    try {
+      String jobId = request.getJobId();
+      JobExecutionResponse jobExecutionResponse = jobExecutorService.getJobResults(jobId);
+      if (!jobExecutionResponse.getStatus().equals("SUCCESS")) {
+        JSONObject jsonObject = new JSONObject();
+        jsonObject.put("status", jobExecutionResponse.getStatus());
+        listener.onResponse(new GetJobQueryResultActionResponse(jsonObject.toString()));
+      } else {
+        ResponseFormatter<QueryResult> formatter =
+            new SimpleJsonResponseFormatter(JsonResponseFormatter.Style.PRETTY);
+        String responseContent =
+            formatter.format(
+                new QueryResult(
+                    jobExecutionResponse.getSchema(),
+                    jobExecutionResponse.getResults(),
+                    Cursor.None));
+        listener.onResponse(new GetJobQueryResultActionResponse(responseContent));
+      }
+    } catch (Exception e) {
+      listener.onFailure(e);
+    }
   }
 }
diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml
new file mode 100644
index 0000000000..ec2c83a4df
--- /dev/null
+++ b/spark/src/main/resources/job-metadata-index-mapping.yml
@@ -0,0 +1,20 @@
+---
+##
+ # Copyright OpenSearch Contributors
+ # SPDX-License-Identifier: Apache-2.0
+##
+
+# Schema file for the .ql-job-metadata index
+# Also "dynamic" is set to "false" so that other fields cannot be added.
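+# A stored document resembles (illustrative values drawn from the tests):
+#   {"jobId": "job-123xxx", "applicationId": "app-xxxxx"}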
+dynamic: false +properties: + jobId: + type: text + fields: + keyword: + type: keyword + applicationId: + type: text + fields: + keyword: + type: keyword \ No newline at end of file diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/job-metadata-index-settings.yml new file mode 100644 index 0000000000..be93f4645c --- /dev/null +++ b/spark/src/main/resources/job-metadata-index-settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .ql-job-metadata index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" \ No newline at end of file diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java new file mode 100644 index 0000000000..36f10cd08b --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -0,0 +1,48 @@ +/* Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_JOB_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; + +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class EmrServerlessClientImplTest { + @Mock private AWSEMRServerless emrServerless; + + @Test + void testStartJobRun() { + StartJobRunResult response = new StartJobRunResult(); + when(emrServerless.startJobRun(any())).thenReturn(response); + + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + emrServerlessClient.startJobRun( + QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, SPARK_SUBMIT_PARAMETERS); + } + + @Test + void testGetJobRunState() { + JobRun jobRun = new JobRun(); + jobRun.setState("Running"); + GetJobRunResult response = new GetJobRunResult(); + response.setJobRun(jobRun); + when(emrServerless.getJobRun(any())).thenReturn(response); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index 2b1020568a..e455e6a049 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -7,5 +7,12 @@ public class TestConstants { public static final String QUERY = "select 1"; + public static final String TEST_DATASOURCE_NAME = 
"test_datasource_name"; public static final String EMR_CLUSTER_ID = "j-123456789"; + public static final String EMR_JOB_ID = "job-123xxx"; + public static final String EMRS_APPLICATION_ID = "app-xxxxx"; + public static final String EMRS_EXECUTION_ROLE = "execution_role"; + public static final String EMRS_DATASOURCE_ROLE = "datasource_role"; + public static final String EMRS_JOB_NAME = "job_name"; + public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java new file mode 100644 index 0000000000..8bd6f1caa8 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -0,0 +1,178 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.client.EmrServerlessClient; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@ExtendWith(MockitoExtension.class) +public class SparkQueryDispatcherTest { + + @Mock private EmrServerlessClient emrServerlessClient; + @Mock private DataSourceService dataSourceService; + @Mock private JobExecutionResponseReader jobExecutionResponseReader; + + @Test + void testDispatch() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, dataSourceService, jobExecutionResponseReader); + when(emrServerlessClient.startJobRun( + QUERY, + "flint-opensearch-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString())) + .thenReturn(EMR_JOB_ID); + when(dataSourceService.getRawDataSourceMetadata("my_glue")) + .thenReturn(constructMyGlueDataSourceMetadata()); + String jobId = sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE); + verify(emrServerlessClient, times(1)) + .startJobRun( + QUERY, + "flint-opensearch-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString()); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + + @Test + void testDispatchWithWrongURI() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + 
            emrServerlessClient, dataSourceService, jobExecutionResponseReader);
+    when(dataSourceService.getRawDataSourceMetadata("my_glue"))
+        .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax());
+    IllegalArgumentException illegalArgumentException =
+        Assertions.assertThrows(
+            IllegalArgumentException.class,
+            () -> sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE));
+    Assertions.assertEquals(
+        "Bad URI in indexstore configuration of the : my_glue datasoure.",
+        illegalArgumentException.getMessage());
+  }
+
+  private DataSourceMetadata constructMyGlueDataSourceMetadata() {
+    DataSourceMetadata dataSourceMetadata = new DataSourceMetadata();
+    dataSourceMetadata.setName("my_glue");
+    dataSourceMetadata.setConnector(DataSourceType.S3GLUE);
+    Map<String, String> properties = new HashMap<>();
+    properties.put("glue.auth.type", "iam_role");
+    properties.put(
+        "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
+    properties.put(
+        "glue.indexstore.opensearch.uri",
+        "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com");
+    properties.put("glue.indexstore.opensearch.auth", "sigv4");
+    properties.put("glue.indexstore.opensearch.region", "eu-west-1");
+    dataSourceMetadata.setProperties(properties);
+    return dataSourceMetadata;
+  }
+
+  private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() {
+    DataSourceMetadata dataSourceMetadata = new DataSourceMetadata();
+    dataSourceMetadata.setName("my_glue");
+    dataSourceMetadata.setConnector(DataSourceType.S3GLUE);
+    Map<String, String> properties = new HashMap<>();
+    properties.put("glue.auth.type", "iam_role");
+    properties.put(
+        "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
+    properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? 
param"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } + + @Test + void testGetQueryResponse() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, dataSourceService, jobExecutionResponseReader); + when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); + JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals("PENDING", result.get("status")); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testGetQueryResponseWithSuccess() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, dataSourceService, jobExecutionResponseReader); + when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); + JSONObject queryResult = new JSONObject(); + queryResult.put("data", "result"); + when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) + .thenReturn(queryResult); + JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(emrServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); + Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); + Assertions.assertEquals("result", result.get("data")); + Assertions.assertEquals("SUCCESS", result.get("status")); + } + + String constructExpectedSparkSubmitParameterString() { + return " --class org.opensearch.sql.FlintJob --conf" + + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + + " --conf" + + " spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory" + + " --conf" + + " spark.jars=s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar,s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar" + + " --conf" + + " spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT" + + " --conf" + + " spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots" + + " --conf" + + " spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" + + " --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" + + " --conf" + + " spark.datasource.flint.host=search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com" + + " --conf spark.datasource.flint.port=-1 --conf" + + " spark.datasource.flint.scheme=https --conf spark.datasource.flint.auth=sigv4 " + + " --conf spark.datasource.flint.region=eu-west-1 --conf" + + " spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + + " --conf spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions " + + " --conf" + + " spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory" + + " --conf" + + " 
spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + + " --conf" + + " spark.emr-serverless.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + + " --conf" + + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog "; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImplTest.java new file mode 100644 index 0000000000..205e47a3a9 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/jobs/JobExecutorServiceImplTest.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.io.IOException; +import java.util.HashMap; +import java.util.Optional; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.jobs.exceptions.JobNotFoundException; +import org.opensearch.sql.spark.jobs.model.JobExecutionResponse; +import org.opensearch.sql.spark.jobs.model.JobMetadata; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateJobResponse; + +@ExtendWith(MockitoExtension.class) +public class JobExecutorServiceImplTest { + + @Mock private SparkQueryDispatcher sparkQueryDispatcher; + @Mock private JobMetadataStorageService jobMetadataStorageService; + @Mock private Settings settings; + + @Test + void testCreateJob() { + JobExecutorServiceImpl jobExecutorService = + new JobExecutorServiceImpl(jobMetadataStorageService, sparkQueryDispatcher, settings); + CreateJobRequest createJobRequest = + new CreateJobRequest("select * from my_glue.default.http_logs"); + when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) + .thenReturn( + "{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}"); + when(sparkQueryDispatcher.dispatch( + "00fd775baqpu4g0p", + "select * from my_glue.default.http_logs", + "arn:aws:iam::270824043731:role/emr-job-execution-role")) + .thenReturn(EMR_JOB_ID); + CreateJobResponse createJobResponse = jobExecutorService.createJob(createJobRequest); + verify(jobMetadataStorageService, times(1)) + .storeJobMetadata(new JobMetadata(EMR_JOB_ID, "00fd775baqpu4g0p")); + verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + verify(sparkQueryDispatcher, times(1)) + .dispatch( + "00fd775baqpu4g0p", + "select * from 
my_glue.default.http_logs", + "arn:aws:iam::270824043731:role/emr-job-execution-role"); + Assertions.assertEquals(EMR_JOB_ID, createJobResponse.getJobId()); + } + + @Test + void testGetJobResultsWithJobNotFoundException() { + JobExecutorServiceImpl jobExecutorService = + new JobExecutorServiceImpl(jobMetadataStorageService, sparkQueryDispatcher, settings); + when(jobMetadataStorageService.getJobMetadata(EMR_JOB_ID)).thenReturn(Optional.empty()); + JobNotFoundException jobNotFoundException = + Assertions.assertThrows( + JobNotFoundException.class, () -> jobExecutorService.getJobResults(EMR_JOB_ID)); + Assertions.assertEquals( + "JobId: " + EMR_JOB_ID + " not found", jobNotFoundException.getMessage()); + verifyNoInteractions(sparkQueryDispatcher); + verifyNoInteractions(settings); + } + + @Test + void testGetJobResultsWithInProgressJob() { + JobExecutorServiceImpl jobExecutorService = + new JobExecutorServiceImpl(jobMetadataStorageService, sparkQueryDispatcher, settings); + when(jobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + JSONObject jobResult = new JSONObject(); + jobResult.put("status", JobRunState.PENDING.toString()); + when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(jobResult); + JobExecutionResponse jobExecutionResponse = jobExecutorService.getJobResults(EMR_JOB_ID); + + Assertions.assertNull(jobExecutionResponse.getResults()); + Assertions.assertNull(jobExecutionResponse.getSchema()); + Assertions.assertEquals("PENDING", jobExecutionResponse.getStatus()); + verifyNoInteractions(settings); + } + + @Test + void testGetJobResultsWithSuccessJob() throws IOException { + when(jobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); + jobResult.put("status", JobRunState.SUCCESS.toString()); + when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(jobResult); + + JobExecutorServiceImpl jobExecutorService = + new JobExecutorServiceImpl(jobMetadataStorageService, sparkQueryDispatcher, settings); + JobExecutionResponse jobExecutionResponse = jobExecutorService.getJobResults(EMR_JOB_ID); + + Assertions.assertEquals("SUCCESS", jobExecutionResponse.getStatus()); + Assertions.assertEquals(1, jobExecutionResponse.getSchema().getColumns().size()); + Assertions.assertEquals("1", jobExecutionResponse.getSchema().getColumns().get(0).getName()); + Assertions.assertEquals( + 1, ((HashMap) jobExecutionResponse.getResults().get(0).value()).get("1")); + verifyNoInteractions(settings); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageServiceTest.java new file mode 100644 index 0000000000..1c599b6c80 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/jobs/OpensearchJobMetadataStorageServiceTest.java @@ -0,0 +1,238 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.jobs; + +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.jobs.OpensearchJobMetadataStorageService.JOB_METADATA_INDEX; + +import java.util.Optional; 
+import org.apache.lucene.search.TotalHits;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Answers;
+import org.mockito.ArgumentMatchers;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.action.DocWriteResponse;
+import org.opensearch.action.admin.indices.create.CreateIndexResponse;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.action.ActionFuture;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+import org.opensearch.sql.spark.jobs.model.JobMetadata;
+
+@ExtendWith(MockitoExtension.class)
+public class OpensearchJobMetadataStorageServiceTest {
+
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  private Client client;
+
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  private ClusterService clusterService;
+
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  private SearchResponse searchResponse;
+
+  @Mock private ActionFuture<SearchResponse> searchResponseActionFuture;
+  @Mock private ActionFuture<CreateIndexResponse> createIndexResponseActionFuture;
+  @Mock private ActionFuture<IndexResponse> indexResponseActionFuture;
+  @Mock private IndexResponse indexResponse;
+  @Mock private SearchHit searchHit;
+  @InjectMocks private OpensearchJobMetadataStorageService opensearchJobMetadataStorageService;
+
+  @Test
+  public void testStoreJobMetadata() {
+
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.FALSE);
+    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
+        .thenReturn(createIndexResponseActionFuture);
+    Mockito.when(createIndexResponseActionFuture.actionGet())
+        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
+    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
+    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
+    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED);
+    JobMetadata jobMetadata = new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID);
+
+    this.opensearchJobMetadataStorageService.storeJobMetadata(jobMetadata);
+
+    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
+    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
+    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
+  }
+
+  @Test
+  public void testStoreJobMetadataWithOutCreatingIndex() {
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.TRUE);
+    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
+    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
+    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED);
+    JobMetadata jobMetadata = new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID);
+
+    this.opensearchJobMetadataStorageService.storeJobMetadata(jobMetadata);
+
+    Mockito.verify(client.admin().indices(), Mockito.times(0)).create(ArgumentMatchers.any());
+    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
+    Mockito.verify(client.threadPool().getThreadContext(), 
Mockito.times(1)).stashContext(); + } + + @Test + public void testStoreJobMetadataWithException() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())) + .thenThrow(new RuntimeException("error while indexing")); + + JobMetadata jobMetadata = new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> this.opensearchJobMetadataStorageService.storeJobMetadata(jobMetadata)); + Assertions.assertEquals( + "java.lang.RuntimeException: error while indexing", runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); + } + + @Test + public void testStoreJobMetadataWithIndexCreationFailed() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX)); + + JobMetadata jobMetadata = new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> this.opensearchJobMetadataStorageService.storeJobMetadata(jobMetadata)); + Assertions.assertEquals( + "Internal server error while creating.ql-job-metadata index:: " + + "Index creation is not acknowledged.", + runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext(); + } + + @Test + public void testStoreJobMetadataFailedWithNotFoundResponse() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + JobMetadata jobMetadata = new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> this.opensearchJobMetadataStorageService.storeJobMetadata(jobMetadata)); + Assertions.assertEquals( + "Saving job metadata information failed with result : not_found", + runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + 
Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
+  }
+
+  @Test
+  public void testGetJobMetadata() {
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(true);
+    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
+    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
+    Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK);
+    Mockito.when(searchResponse.getHits())
+        .thenReturn(
+            new SearchHits(
+                new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F));
+    JobMetadata jobMetadata = new JobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID);
+    Mockito.when(searchHit.getSourceAsString()).thenReturn(jobMetadata.toString());
+
+    Optional<JobMetadata> jobMetadataOptional =
+        opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID);
+    Assertions.assertTrue(jobMetadataOptional.isPresent());
+    Assertions.assertEquals(EMR_JOB_ID, jobMetadataOptional.get().getJobId());
+    Assertions.assertEquals(EMRS_APPLICATION_ID, jobMetadataOptional.get().getApplicationId());
+  }
+
+  @Test
+  public void testGetJobMetadataWith404SearchResponse() {
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(true);
+    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
+    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
+    Mockito.when(searchResponse.status()).thenReturn(RestStatus.NOT_FOUND);
+
+    RuntimeException runtimeException =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
+    Assertions.assertEquals(
+        "Fetching job metadata information failed with status : NOT_FOUND",
+        runtimeException.getMessage());
+  }
+
+  @Test
+  public void testGetJobMetadataWithParsingFailed() {
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(true);
+    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
+    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
+    Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK);
+    Mockito.when(searchResponse.getHits())
+        .thenReturn(
+            new SearchHits(
+                new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F));
+    Mockito.when(searchHit.getSourceAsString()).thenReturn("..tesJOBs");
+
+    Assertions.assertThrows(
+        RuntimeException.class,
+        () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
+  }
+
+  @Test
+  public void testGetJobMetadataWithNoIndex() {
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.FALSE);
+    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
+        .thenReturn(createIndexResponseActionFuture);
+    Mockito.when(createIndexResponseActionFuture.actionGet())
+        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
+    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
+
+    Optional<JobMetadata> jobMetadata =
+        opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID);
+
+    Assertions.assertFalse(jobMetadata.isPresent());
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/JobExecutionResponseReaderTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/JobExecutionResponseReaderTest.java
new file mode 100644
index 0000000000..87b830c9e9
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/response/JobExecutionResponseReaderTest.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.response;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME;
+
+import java.util.Map;
+import org.apache.lucene.search.TotalHits;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.client.Client;
+import org.opensearch.common.action.ActionFuture;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+
+@ExtendWith(MockitoExtension.class)
+public class JobExecutionResponseReaderTest {
+  @Mock private Client client;
+  @Mock private SearchResponse searchResponse;
+  @Mock private SearchHit searchHit;
+  @Mock private ActionFuture<SearchResponse> searchResponseActionFuture;
+
+  @Test
+  public void testGetResultFromOpensearchIndex() {
+    when(client.search(any())).thenReturn(searchResponseActionFuture);
+    when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
+    when(searchResponse.status()).thenReturn(RestStatus.OK);
+    when(searchResponse.getHits())
+        .thenReturn(
+            new SearchHits(
+                new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
+    Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));
+    JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
+    assertFalse(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID).isEmpty());
+  }
+
+  @Test
+  public void testInvalidSearchResponse() {
+    when(client.search(any())).thenReturn(searchResponseActionFuture);
+    when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
+    when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT);
+
+    JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
+    RuntimeException exception =
+        assertThrows(
+            RuntimeException.class,
+            () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID));
+    Assertions.assertEquals(
+        "Fetching result from "
+            + SPARK_RESPONSE_BUFFER_INDEX_NAME
+            + " index failed with status : "
+            + RestStatus.NO_CONTENT,
+        exception.getMessage());
+  }
+
+  @Test
+  public void testSearchFailure() {
+    when(client.search(any())).thenThrow(RuntimeException.class);
+    JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
+    assertThrows(
+        RuntimeException.class,
+        () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID));
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java
index 211561ac72..e234454021 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java
@@ -10,7 +10,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.when;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME;
 
 import java.util.Map;
 import org.apache.lucene.search.TotalHits;
@@ -69,7 +69,7 @@ public void testInvalidSearchResponse() {
     assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex());
     Assertions.assertEquals(
         "Fetching result from "
-            + SPARK_INDEX_NAME
+            + SPARK_RESPONSE_BUFFER_INDEX_NAME
             + " index failed with status : "
             + RestStatus.NO_CONTENT,
         exception.getMessage());
diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java
index 4357899368..36f095b668 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java
@@ -7,6 +7,11 @@
 
 package org.opensearch.sql.spark.transport;
 
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
 import java.util.HashSet;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeEach;
@@ -19,7 +24,9 @@
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl;
 import org.opensearch.sql.spark.rest.model.CreateJobRequest;
+import org.opensearch.sql.spark.rest.model.CreateJobResponse;
 import org.opensearch.sql.spark.transport.model.CreateJobActionRequest;
 import org.opensearch.sql.spark.transport.model.CreateJobActionResponse;
 import org.opensearch.tasks.Task;
@@ -30,26 +37,43 @@ public class TransportCreateJobRequestActionTest {
 
   @Mock private TransportService transportService;
   @Mock private TransportCreateJobRequestAction action;
+  @Mock private JobExecutorServiceImpl jobExecutorService;
   @Mock private Task task;
   @Mock private ActionListener<CreateJobActionResponse> actionListener;
 
   @Captor private ArgumentCaptor<CreateJobActionResponse> createJobActionResponseArgumentCaptor;
+  @Captor private ArgumentCaptor<Exception> exceptionArgumentCaptor;
 
   @BeforeEach
   public void setUp() {
     action =
-        new TransportCreateJobRequestAction(transportService, new ActionFilters(new HashSet<>()));
+        new TransportCreateJobRequestAction(
+            transportService, new ActionFilters(new HashSet<>()), jobExecutorService);
   }
 
   @Test
   public void testDoExecute() {
     CreateJobRequest createJobRequest = new CreateJobRequest("source = my_glue.default.alb_logs");
     CreateJobActionRequest request = new CreateJobActionRequest(createJobRequest);
-
+    when(jobExecutorService.createJob(createJobRequest)).thenReturn(new CreateJobResponse("123"));
     action.doExecute(task, request, actionListener);
     Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture());
     CreateJobActionResponse createJobActionResponse =
         createJobActionResponseArgumentCaptor.getValue();
-    Assertions.assertEquals("submitted_job", createJobActionResponse.getResult());
+    Assertions.assertEquals(
+        "{\n" + "  \"jobId\": \"123\"\n" + "}", createJobActionResponse.getResult());
+  }
\"123\"\n" + "}", createJobActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + CreateJobRequest createJobRequest = new CreateJobRequest("source = my_glue.default.alb_logs"); + CreateJobActionRequest request = new CreateJobActionRequest(createJobRequest); + doThrow(new RuntimeException("Error")).when(jobExecutorService).createJob(createJobRequest); + action.doExecute(task, request, actionListener); + verify(jobExecutorService, times(1)).createJob(createJobRequest); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("Error", exception.getMessage()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java index f22adead49..2f61bcff43 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java @@ -7,6 +7,17 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; import java.util.HashSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -15,10 +26,13 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.jobs.JobExecutorServiceImpl; +import org.opensearch.sql.spark.jobs.exceptions.JobNotFoundException; +import org.opensearch.sql.spark.jobs.model.JobExecutionResponse; import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; import org.opensearch.tasks.Task; @@ -31,24 +45,92 @@ public class TransportGetQueryResultRequestActionTest { @Mock private TransportGetQueryResultRequestAction action; @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private JobExecutorServiceImpl jobExecutorService; @Captor private ArgumentCaptor createJobActionResponseArgumentCaptor; + @Captor private ArgumentCaptor exceptionArgumentCaptor; + @BeforeEach public void setUp() { action = new TransportGetQueryResultRequestAction( - transportService, new ActionFilters(new HashSet<>())); + transportService, new ActionFilters(new HashSet<>()), jobExecutorService); } @Test - public void testDoExecuteForSingleJob() { + public void testDoExecute() { GetJobQueryResultActionRequest request = new GetJobQueryResultActionRequest("jobId"); + JobExecutionResponse jobExecutionResponse = new JobExecutionResponse("IN_PROGRESS", null, null); + 
when(jobExecutorService.getJobResults("jobId")).thenReturn(jobExecutionResponse); action.doExecute(task, request, actionListener); - Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); GetJobQueryResultActionResponse getJobQueryResultActionResponse = createJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("job result", getJobQueryResultActionResponse.getResult()); + Assertions.assertEquals( + "{\"status\":\"IN_PROGRESS\"}", getJobQueryResultActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithSuccessResponse() { + GetJobQueryResultActionRequest request = new GetJobQueryResultActionRequest("jobId"); + ExecutionEngine.Schema schema = + new ExecutionEngine.Schema( + ImmutableList.of( + new ExecutionEngine.Schema.Column("name", "name", STRING), + new ExecutionEngine.Schema.Column("age", "age", INTEGER))); + JobExecutionResponse jobExecutionResponse = + new JobExecutionResponse( + "SUCCESS", + schema, + Arrays.asList( + tupleValue(ImmutableMap.of("name", "John", "age", 20)), + tupleValue(ImmutableMap.of("name", "Smith", "age", 30)))); + when(jobExecutorService.getJobResults("jobId")).thenReturn(jobExecutionResponse); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + GetJobQueryResultActionResponse getJobQueryResultActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + + " \"schema\": [\n" + + " {\n" + + " \"name\": \"name\",\n" + + " \"type\": \"string\"\n" + + " },\n" + + " {\n" + + " \"name\": \"age\",\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " ],\n" + + " \"datarows\": [\n" + + " [\n" + + " \"John\",\n" + + " 20\n" + + " ],\n" + + " [\n" + + " \"Smith\",\n" + + " 30\n" + + " ]\n" + + " ],\n" + + " \"total\": 2,\n" + + " \"size\": 2\n" + + "}", + getJobQueryResultActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + GetJobQueryResultActionRequest request = new GetJobQueryResultActionRequest("123"); + doThrow(new JobNotFoundException("JobId 123 not found")) + .when(jobExecutorService) + .getJobResults("123"); + action.doExecute(task, request, actionListener); + verify(jobExecutorService, times(1)).getJobResults("123"); + verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("JobId 123 not found", exception.getMessage()); } }