From 75e9f04f29c3831275a474246788d9637cbf73fe Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 8 Sep 2023 17:02:34 -0700 Subject: [PATCH 1/5] Glue datasource support (#2055) Signed-off-by: Vamsi Manohar --- .../sql/datasource/model/DataSourceType.java | 4 +- .../glue/GlueDataSourceFactory.java | 56 +++++++++ .../utils/DatasourceValidationUtils.java | 3 +- .../glue/GlueDataSourceFactoryTest.java | 115 ++++++++++++++++++ .../utils/DatasourceValidationUtilsTest.java | 4 +- .../{ => connectors}/prometheus_connector.rst | 0 .../ppl/admin/connectors/s3glue_connector.rst | 68 +++++++++++ .../{ => connectors}/spark_connector.rst | 0 .../org/opensearch/sql/plugin/SQLPlugin.java | 2 + .../storage/PrometheusStorageFactoryTest.java | 6 +- 10 files changed, 250 insertions(+), 8 deletions(-) create mode 100644 datasources/src/main/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactory.java create mode 100644 datasources/src/test/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactoryTest.java rename docs/user/ppl/admin/{ => connectors}/prometheus_connector.rst (100%) create mode 100644 docs/user/ppl/admin/connectors/s3glue_connector.rst rename docs/user/ppl/admin/{ => connectors}/spark_connector.rst (100%) diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java index 5010e41942..a3c7c73d6b 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java @@ -8,7 +8,9 @@ public enum DataSourceType { PROMETHEUS("prometheus"), OPENSEARCH("opensearch"), - SPARK("spark"); + SPARK("spark"), + S3GLUE("s3glue"); + private String text; DataSourceType(String text) { diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactory.java b/datasources/src/main/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactory.java new file mode 100644 index 0000000000..24f94376bf --- /dev/null +++ b/datasources/src/main/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactory.java @@ -0,0 +1,56 @@ +package org.opensearch.sql.datasources.glue; + +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.Map; +import java.util.Set; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.utils.DatasourceValidationUtils; +import org.opensearch.sql.storage.DataSourceFactory; + +@RequiredArgsConstructor +public class GlueDataSourceFactory implements DataSourceFactory { + + private final Settings pluginSettings; + + // Glue configuration properties + public static final String GLUE_AUTH_TYPE = "glue.auth.type"; + public static final String GLUE_ROLE_ARN = "glue.auth.role_arn"; + public static final String FLINT_URI = "glue.indexstore.opensearch.uri"; + public static final String FLINT_AUTH = "glue.indexstore.opensearch.auth"; + public static final String FLINT_REGION = "glue.indexstore.opensearch.region"; + + @Override + public DataSourceType getDataSourceType() { + return DataSourceType.S3GLUE; + } + + @Override + public DataSource createDataSource(DataSourceMetadata metadata) { + try { + validateGlueDataSourceConfiguration(metadata.getProperties()); + return new DataSource( 
+ metadata.getName(), + metadata.getConnector(), + (dataSourceSchemaName, tableName) -> { + throw new UnsupportedOperationException("Glue storage engine is not supported."); + }); + } catch (URISyntaxException | UnknownHostException e) { + throw new IllegalArgumentException("Invalid flint host in properties."); + } + } + + private void validateGlueDataSourceConfiguration(Map dataSourceMetadataConfig) + throws URISyntaxException, UnknownHostException { + DatasourceValidationUtils.validateLengthAndRequiredFields( + dataSourceMetadataConfig, + Set.of(GLUE_AUTH_TYPE, GLUE_ROLE_ARN, FLINT_URI, FLINT_REGION, FLINT_AUTH)); + DatasourceValidationUtils.validateHost( + dataSourceMetadataConfig.get(FLINT_URI), + pluginSettings.getSettingValue(Settings.Key.DATASOURCES_URI_HOSTS_DENY_LIST)); + } +} diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtils.java b/datasources/src/main/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtils.java index e779e8e04d..6f03ffb9a4 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtils.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtils.java @@ -40,8 +40,7 @@ public static void validateLengthAndRequiredFields( StringBuilder errorStringBuilder = new StringBuilder(); if (missingFields.size() > 0) { errorStringBuilder.append( - String.format( - "Missing %s fields in the Prometheus connector properties.", missingFields)); + String.format("Missing %s fields in the connector properties.", missingFields)); } if (invalidLengthFields.size() > 0) { diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactoryTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactoryTest.java new file mode 100644 index 0000000000..b018e5f9dc --- /dev/null +++ b/datasources/src/test/java/org/opensearch/sql/datasources/glue/GlueDataSourceFactoryTest.java @@ -0,0 +1,115 @@ +package org.opensearch.sql.datasources.glue; + +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; + +@ExtendWith(MockitoExtension.class) +public class GlueDataSourceFactoryTest { + + @Mock private Settings settings; + + @Test + void testGetConnectorType() { + GlueDataSourceFactory glueDatasourceFactory = new GlueDataSourceFactory(settings); + Assertions.assertEquals(DataSourceType.S3GLUE, glueDatasourceFactory.getDataSourceType()); + } + + @Test + @SneakyThrows + void testCreateGLueDatSource() { + when(settings.getSettingValue(Settings.Key.DATASOURCES_URI_HOSTS_DENY_LIST)) + .thenReturn(Collections.emptyList()); + GlueDataSourceFactory glueDatasourceFactory = new GlueDataSourceFactory(settings); + + DataSourceMetadata metadata = new DataSourceMetadata(); + HashMap properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put("glue.auth.role_arn", "role_arn"); + 
properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "false"); + properties.put("glue.indexstore.opensearch.region", "us-west-2"); + + metadata.setName("my_glue"); + metadata.setConnector(DataSourceType.S3GLUE); + metadata.setProperties(properties); + DataSource dataSource = glueDatasourceFactory.createDataSource(metadata); + Assertions.assertEquals(DataSourceType.S3GLUE, dataSource.getConnectorType()); + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + dataSource + .getStorageEngine() + .getTable(new DataSourceSchemaName("my_glue", "default"), "alb_logs")); + Assertions.assertEquals( + "Glue storage engine is not supported.", unsupportedOperationException.getMessage()); + } + + @Test + @SneakyThrows + void testCreateGLueDatSourceWithInvalidFlintHost() { + when(settings.getSettingValue(Settings.Key.DATASOURCES_URI_HOSTS_DENY_LIST)) + .thenReturn(List.of("127.0.0.0/8")); + GlueDataSourceFactory glueDatasourceFactory = new GlueDataSourceFactory(settings); + + DataSourceMetadata metadata = new DataSourceMetadata(); + HashMap properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put("glue.auth.role_arn", "role_arn"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "false"); + properties.put("glue.indexstore.opensearch.region", "us-west-2"); + + metadata.setName("my_glue"); + metadata.setConnector(DataSourceType.S3GLUE); + metadata.setProperties(properties); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, () -> glueDatasourceFactory.createDataSource(metadata)); + Assertions.assertEquals( + "Disallowed hostname in the uri. " + + "Validate with plugins.query.datasources.uri.hosts.denylist config", + illegalArgumentException.getMessage()); + } + + @Test + @SneakyThrows + void testCreateGLueDatSourceWithInvalidFlintHostSyntax() { + when(settings.getSettingValue(Settings.Key.DATASOURCES_URI_HOSTS_DENY_LIST)) + .thenReturn(List.of("127.0.0.0/8")); + GlueDataSourceFactory glueDatasourceFactory = new GlueDataSourceFactory(settings); + + DataSourceMetadata metadata = new DataSourceMetadata(); + HashMap properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put("glue.auth.role_arn", "role_arn"); + properties.put( + "glue.indexstore.opensearch.uri", + "http://dummyprometheus.com:9090? 
paramt::localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "false"); + properties.put("glue.indexstore.opensearch.region", "us-west-2"); + + metadata.setName("my_glue"); + metadata.setConnector(DataSourceType.S3GLUE); + metadata.setProperties(properties); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, () -> glueDatasourceFactory.createDataSource(metadata)); + Assertions.assertEquals( + "Invalid flint host in properties.", illegalArgumentException.getMessage()); + } +} diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtilsTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtilsTest.java index 15e921e72a..2b77c1938a 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtilsTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/utils/DatasourceValidationUtilsTest.java @@ -48,7 +48,7 @@ public void testValidateLengthAndRequiredFieldsWithAbsentField() { DatasourceValidationUtils.validateLengthAndRequiredFields( config, Set.of("s3.uri", "s3.auth.type"))); Assertions.assertEquals( - "Missing [s3.auth.type] fields in the Prometheus connector properties.", + "Missing [s3.auth.type] fields in the connector properties.", illegalArgumentException.getMessage()); } @@ -63,7 +63,7 @@ public void testValidateLengthAndRequiredFieldsWithInvalidLength() { DatasourceValidationUtils.validateLengthAndRequiredFields( config, Set.of("s3.uri", "s3.auth.type"))); Assertions.assertEquals( - "Missing [s3.auth.type] fields in the Prometheus connector properties.Fields " + "Missing [s3.auth.type] fields in the connector properties.Fields " + "[s3.uri] exceeds more than 1000 characters.", illegalArgumentException.getMessage()); } diff --git a/docs/user/ppl/admin/prometheus_connector.rst b/docs/user/ppl/admin/connectors/prometheus_connector.rst similarity index 100% rename from docs/user/ppl/admin/prometheus_connector.rst rename to docs/user/ppl/admin/connectors/prometheus_connector.rst diff --git a/docs/user/ppl/admin/connectors/s3glue_connector.rst b/docs/user/ppl/admin/connectors/s3glue_connector.rst new file mode 100644 index 0000000000..640eb90283 --- /dev/null +++ b/docs/user/ppl/admin/connectors/s3glue_connector.rst @@ -0,0 +1,68 @@ +.. highlight:: sh + +==================== +S3Glue Connector +==================== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 1 + + +Introduction +============ + +The s3Glue connector provides a way to query s3 files using Glue as the metadata store and Spark as the execution engine. +This page covers s3Glue datasource configuration and how to query an s3Glue datasource. + + +Required resources for s3 Glue Connector +======================================== +* S3: This is where the data lies. +* Spark Execution Engine: Query execution happens on Spark. +* Glue Metadata store: Glue takes care of table metadata. +* OpenSearch: Index for s3 data lies in OpenSearch and also acts as a temporary buffer for query results. + +We currently only support emr-serverless as the spark execution engine and Glue as the metadata store. We will add more support in the future. + +Glue Connector Properties in DataSource Configuration +======================================================== +Glue Connector Properties. + +* ``glue.auth.type`` [Required] + * This parameter provides the authentication type information required for the execution engine to connect to glue.
+ * S3 Glue connector currently only supports ``iam_role`` authentication and the below parameter is required. + * ``glue.auth.role_arn`` +* ``glue.indexstore.opensearch.*`` [Required] + * These parameters provide the OpenSearch domain host information for the glue connector. This OpenSearch instance is used for writing index data back and also acts as a temporary buffer for query results. + * ``glue.indexstore.opensearch.uri`` [Required] + * ``glue.indexstore.opensearch.auth`` [Required] + * Default value for auth is ``false``. + * ``glue.indexstore.opensearch.region`` [Required] + * Default value for region is ``us-west-2``. + +Sample Glue datasource configuration +======================================== + +Glue datasource configuration:: + + [{ + "name" : "my_glue", + "connector": "s3glue", + "properties" : { + "glue.auth.type": "iam_role", + "glue.auth.role_arn": "role_arn", + "glue.indexstore.opensearch.uri": "http://localhost:9200", + "glue.indexstore.opensearch.auth" :"false", + "glue.indexstore.opensearch.region": "us-west-2" + } + }] + + +Sample s3Glue datasource queries +================================ + + + diff --git a/docs/user/ppl/admin/spark_connector.rst b/docs/user/ppl/admin/connectors/spark_connector.rst similarity index 100% rename from docs/user/ppl/admin/spark_connector.rst rename to docs/user/ppl/admin/connectors/spark_connector.rst diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f20de87d61..4629c4e29f 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -51,6 +51,7 @@ import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelper; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.datasources.encryptor.EncryptorImpl; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse; import org.opensearch.sql.datasources.model.transport.DeleteDataSourceActionResponse; import org.opensearch.sql.datasources.model.transport.GetDataSourceActionResponse; @@ -241,6 +242,7 @@ private DataSourceServiceImpl createDataSourceService() { new OpenSearchNodeClient(this.client), pluginSettings)) .add(new PrometheusStorageFactory(pluginSettings)) .add(new SparkStorageFactory(this.client, pluginSettings)) + .add(new GlueDataSourceFactory(pluginSettings)) .build(), dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactoryTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactoryTest.java index 41d439d120..f17a4b10d0 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactoryTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactoryTest.java @@ -81,7 +81,7 @@ void testGetStorageEngineWithMissingURI() { IllegalArgumentException.class, () -> prometheusStorageFactory.getStorageEngine(properties)); Assertions.assertEquals( - "Missing [prometheus.uri] fields " + "in the Prometheus connector properties.", + "Missing [prometheus.uri] fields " + "in the connector properties.", exception.getMessage()); } @@ -99,7 +99,7 @@ void testGetStorageEngineWithMissingRegionInAWS() { IllegalArgumentException.class, () -> prometheusStorageFactory.getStorageEngine(properties));
Assertions.assertEquals( - "Missing [prometheus.auth.region] fields in the " + "Prometheus connector properties.", + "Missing [prometheus.auth.region] fields in the connector properties.", exception.getMessage()); } @@ -118,7 +118,7 @@ void testGetStorageEngineWithLongConfigProperties() { () -> prometheusStorageFactory.getStorageEngine(properties)); Assertions.assertEquals( "Missing [prometheus.auth.region] fields in the " - + "Prometheus connector properties." + + "connector properties." + "Fields [prometheus.uri] exceeds more than 1000 characters.", exception.getMessage()); } From cda01e933fcdfe3454ae2c5c8981ae6bca087317 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 8 Sep 2023 17:06:11 -0700 Subject: [PATCH 2/5] Initial commit of new job APIs (#2050) Signed-off-by: Vamsi Manohar --- .../org/opensearch/sql/plugin/SQLPlugin.java | 27 +- spark/build.gradle | 5 +- .../spark/rest/RestJobManagementAction.java | 262 ++++++++++++++++++ .../spark/rest/model/CreateJobRequest.java | 35 +++ .../TransportCreateJobRequestAction.java | 39 +++ .../TransportDeleteJobRequestAction.java | 39 +++ .../TransportGetJobRequestAction.java | 52 ++++ .../TransportGetQueryResultRequestAction.java | 42 +++ .../model/CreateJobActionRequest.java | 34 +++ .../model/CreateJobActionResponse.java | 31 +++ .../model/DeleteJobActionRequest.java | 30 ++ .../model/DeleteJobActionResponse.java | 31 +++ .../transport/model/GetJobActionRequest.java | 33 +++ .../transport/model/GetJobActionResponse.java | 31 +++ .../model/GetJobQueryResultActionRequest.java | 31 +++ .../GetJobQueryResultActionResponse.java | 31 +++ .../TransportCreateJobRequestActionTest.java | 55 ++++ .../TransportDeleteJobRequestActionTest.java | 53 ++++ .../TransportGetJobRequestActionTest.java | 60 ++++ ...nsportGetQueryResultRequestActionTest.java | 54 ++++ .../org.mockito.plugins.MockMaker | 1 + 21 files changed, 973 insertions(+), 3 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java create mode 100644 
spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java create mode 100644 spark/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 4629c4e29f..80e1a6b1a3 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -83,7 +83,16 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; +import org.opensearch.sql.spark.rest.RestJobManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; +import org.opensearch.sql.spark.transport.TransportCreateJobRequestAction; +import org.opensearch.sql.spark.transport.TransportDeleteJobRequestAction; +import org.opensearch.sql.spark.transport.TransportGetJobRequestAction; +import org.opensearch.sql.spark.transport.TransportGetQueryResultRequestAction; +import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; +import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; +import org.opensearch.sql.spark.transport.model.GetJobActionResponse; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; @@ -132,7 +141,8 @@ public List getRestHandlers( new RestSqlStatsAction(settings, restController), new RestPPLStatsAction(settings, restController), new RestQuerySettingsAction(settings, restController), - new RestDataSourceQueryAction()); + new RestDataSourceQueryAction(), + new RestJobManagementAction()); } /** Register action and handler so that transportClient can find proxy for action. 
*/ @@ -156,7 +166,20 @@ public List getRestHandlers( new ActionHandler<>( new ActionType<>( TransportDeleteDataSourceAction.NAME, DeleteDataSourceActionResponse::new), - TransportDeleteDataSourceAction.class)); + TransportDeleteDataSourceAction.class), + new ActionHandler<>( + new ActionType<>(TransportCreateJobRequestAction.NAME, CreateJobActionResponse::new), + TransportCreateJobRequestAction.class), + new ActionHandler<>( + new ActionType<>(TransportGetJobRequestAction.NAME, GetJobActionResponse::new), + TransportGetJobRequestAction.class), + new ActionHandler<>( + new ActionType<>( + TransportGetQueryResultRequestAction.NAME, GetJobQueryResultActionResponse::new), + TransportGetQueryResultRequestAction.class), + new ActionHandler<>( + new ActionType<>(TransportDeleteJobRequestAction.NAME, DeleteJobActionResponse::new), + TransportDeleteJobRequestAction.class)); } @Override diff --git a/spark/build.gradle b/spark/build.gradle index 89842e5ea8..b93e3327ce 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -25,6 +25,7 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0' testImplementation 'junit:junit:4.13.1' + testImplementation "org.opensearch.test:framework:${opensearch_version}" } test { @@ -53,7 +54,9 @@ jacocoTestCoverageVerification { rule { element = 'CLASS' excludes = [ - 'org.opensearch.sql.spark.data.constants.*' + 'org.opensearch.sql.spark.data.constants.*', + 'org.opensearch.sql.spark.rest.*', + 'org.opensearch.sql.spark.transport.model.*' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java new file mode 100644 index 0000000000..669cbb6aca --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java @@ -0,0 +1,262 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest; + +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.rest.RestRequest.Method.DELETE; +import static org.opensearch.rest.RestRequest.Method.GET; +import static org.opensearch.rest.RestRequest.Method.POST; + +import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.sql.datasources.exceptions.ErrorMessage; +import org.opensearch.sql.datasources.utils.Scheduler; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.transport.TransportCreateJobRequestAction; +import org.opensearch.sql.spark.transport.TransportDeleteJobRequestAction; +import org.opensearch.sql.spark.transport.TransportGetJobRequestAction; +import org.opensearch.sql.spark.transport.TransportGetQueryResultRequestAction; +import 
org.opensearch.sql.spark.transport.model.CreateJobActionRequest; +import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; +import org.opensearch.sql.spark.transport.model.DeleteJobActionRequest; +import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; +import org.opensearch.sql.spark.transport.model.GetJobActionRequest; +import org.opensearch.sql.spark.transport.model.GetJobActionResponse; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; + +public class RestJobManagementAction extends BaseRestHandler { + + public static final String JOB_ACTIONS = "job_actions"; + public static final String BASE_JOB_ACTION_URL = "/_plugins/_query/_jobs"; + + private static final Logger LOG = LogManager.getLogger(RestJobManagementAction.class); + + @Override + public String getName() { + return JOB_ACTIONS; + } + + @Override + public List routes() { + return ImmutableList.of( + + /* + * + * Create a new job with spark execution engine. + * Request URL: POST + * Request body: + * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionRequest] + * Response body: + * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionResponse] + */ + new Route(POST, BASE_JOB_ACTION_URL), + + /* + * + * GET jobs with in spark execution engine. + * Request URL: GET + * Request body: + * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionRequest] + * Response body: + * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionResponse] + */ + new Route(GET, String.format(Locale.ROOT, "%s/{%s}", BASE_JOB_ACTION_URL, "jobId")), + new Route(GET, BASE_JOB_ACTION_URL), + + /* + * + * Cancel a job within spark execution engine. + * Request URL: DELETE + * Request body: + * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionRequest] + * Response body: + * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionResponse] + */ + new Route(DELETE, String.format(Locale.ROOT, "%s/{%s}", BASE_JOB_ACTION_URL, "jobId")), + + /* + * GET query result from job {{jobId}} execution. 
+ * Request URL: GET + * Request body: + * Ref [org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest] + * Response body: + * Ref [org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse] + */ + new Route(GET, String.format(Locale.ROOT, "%s/{%s}/result", BASE_JOB_ACTION_URL, "jobId"))); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) + throws IOException { + switch (restRequest.method()) { + case POST: + return executePostRequest(restRequest, nodeClient); + case GET: + return executeGetRequest(restRequest, nodeClient); + case DELETE: + return executeDeleteRequest(restRequest, nodeClient); + default: + return restChannel -> + restChannel.sendResponse( + new BytesRestResponse( + RestStatus.METHOD_NOT_ALLOWED, String.valueOf(restRequest.method()))); + } + } + + private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) + throws IOException { + CreateJobRequest submitJobRequest = + CreateJobRequest.fromXContentParser(restRequest.contentParser()); + return restChannel -> + Scheduler.schedule( + nodeClient, + () -> + nodeClient.execute( + TransportCreateJobRequestAction.ACTION_TYPE, + new CreateJobActionRequest(submitJobRequest), + new ActionListener<>() { + @Override + public void onResponse(CreateJobActionResponse createJobActionResponse) { + restChannel.sendResponse( + new BytesRestResponse( + RestStatus.CREATED, + "application/json; charset=UTF-8", + submitJobRequest.getQuery())); + } + + @Override + public void onFailure(Exception e) { + handleException(e, restChannel); + } + })); + } + + private RestChannelConsumer executeGetRequest(RestRequest restRequest, NodeClient nodeClient) { + Boolean isResultRequest = restRequest.rawPath().contains("result"); + if (isResultRequest) { + return executeGetJobQueryResultRequest(nodeClient, restRequest); + } else { + return executeGetJobRequest(nodeClient, restRequest); + } + } + + private RestChannelConsumer executeGetJobQueryResultRequest( + NodeClient nodeClient, RestRequest restRequest) { + String jobId = restRequest.param("jobId"); + return restChannel -> + Scheduler.schedule( + nodeClient, + () -> + nodeClient.execute( + TransportGetQueryResultRequestAction.ACTION_TYPE, + new GetJobQueryResultActionRequest(jobId), + new ActionListener<>() { + @Override + public void onResponse( + GetJobQueryResultActionResponse getJobQueryResultActionResponse) { + restChannel.sendResponse( + new BytesRestResponse( + RestStatus.OK, + "application/json; charset=UTF-8", + getJobQueryResultActionResponse.getResult())); + } + + @Override + public void onFailure(Exception e) { + handleException(e, restChannel); + } + })); + } + + private RestChannelConsumer executeGetJobRequest(NodeClient nodeClient, RestRequest restRequest) { + String jobId = restRequest.param("jobId"); + return restChannel -> + Scheduler.schedule( + nodeClient, + () -> + nodeClient.execute( + TransportGetJobRequestAction.ACTION_TYPE, + new GetJobActionRequest(jobId), + new ActionListener<>() { + @Override + public void onResponse(GetJobActionResponse getJobActionResponse) { + restChannel.sendResponse( + new BytesRestResponse( + RestStatus.OK, + "application/json; charset=UTF-8", + getJobActionResponse.getResult())); + } + + @Override + public void onFailure(Exception e) { + handleException(e, restChannel); + } + })); + } + + private void handleException(Exception e, RestChannel restChannel) { + if (e instanceof OpenSearchException) { + OpenSearchException exception = 
(OpenSearchException) e; + reportError(restChannel, exception, exception.status()); + } else { + LOG.error("Error happened during request handling", e); + if (isClientError(e)) { + reportError(restChannel, e, BAD_REQUEST); + } else { + reportError(restChannel, e, SERVICE_UNAVAILABLE); + } + } + } + + private RestChannelConsumer executeDeleteRequest(RestRequest restRequest, NodeClient nodeClient) { + String jobId = restRequest.param("jobId"); + return restChannel -> + Scheduler.schedule( + nodeClient, + () -> + nodeClient.execute( + TransportDeleteJobRequestAction.ACTION_TYPE, + new DeleteJobActionRequest(jobId), + new ActionListener<>() { + @Override + public void onResponse(DeleteJobActionResponse deleteJobActionResponse) { + restChannel.sendResponse( + new BytesRestResponse( + RestStatus.OK, + "application/json; charset=UTF-8", + deleteJobActionResponse.getResult())); + } + + @Override + public void onFailure(Exception e) { + handleException(e, restChannel); + } + })); + } + + private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { + channel.sendResponse( + new BytesRestResponse(status, new ErrorMessage(e, status.getStatus()).toString())); + } + + private static boolean isClientError(Exception e) { + return e instanceof IllegalArgumentException || e instanceof IllegalStateException; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java new file mode 100644 index 0000000000..ef29e857c8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import lombok.AllArgsConstructor; +import lombok.Data; +import org.opensearch.core.xcontent.XContentParser; + +@Data +@AllArgsConstructor +public class CreateJobRequest { + + private String query; + + public static CreateJobRequest fromXContentParser(XContentParser parser) throws IOException { + String query = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (fieldName.equals("query")) { + query = parser.textOrNull(); + } else { + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + } + return new CreateJobRequest(query); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java new file mode 100644 index 0000000000..53ae9fad90 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java @@ -0,0 +1,39 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; +import 
org.opensearch.sql.spark.transport.model.CreateJobActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportCreateJobRequestAction + extends HandledTransportAction { + + public static final String NAME = "cluster:admin/opensearch/ql/jobs/create"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, CreateJobActionResponse::new); + + @Inject + public TransportCreateJobRequestAction( + TransportService transportService, ActionFilters actionFilters) { + super(NAME, transportService, actionFilters, CreateJobActionRequest::new); + } + + @Override + protected void doExecute( + Task task, CreateJobActionRequest request, ActionListener listener) { + String responseContent = "submitted_job"; + listener.onResponse(new CreateJobActionResponse(responseContent)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java new file mode 100644 index 0000000000..dcccb76272 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java @@ -0,0 +1,39 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.DeleteJobActionRequest; +import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportDeleteJobRequestAction + extends HandledTransportAction { + + public static final String NAME = "cluster:admin/opensearch/ql/jobs/delete"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, DeleteJobActionResponse::new); + + @Inject + public TransportDeleteJobRequestAction( + TransportService transportService, ActionFilters actionFilters) { + super(NAME, transportService, actionFilters, DeleteJobActionRequest::new); + } + + @Override + protected void doExecute( + Task task, DeleteJobActionRequest request, ActionListener listener) { + String responseContent = "deleted_job"; + listener.onResponse(new DeleteJobActionResponse(responseContent)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java new file mode 100644 index 0000000000..96e002bd81 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java @@ -0,0 +1,52 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.GetJobActionRequest; +import org.opensearch.sql.spark.transport.model.GetJobActionResponse; +import org.opensearch.tasks.Task; +import 
org.opensearch.transport.TransportService; + +public class TransportGetJobRequestAction + extends HandledTransportAction { + + public static final String NAME = "cluster:admin/opensearch/ql/jobs/read"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, GetJobActionResponse::new); + + @Inject + public TransportGetJobRequestAction( + TransportService transportService, ActionFilters actionFilters) { + super(NAME, transportService, actionFilters, GetJobActionRequest::new); + } + + @Override + protected void doExecute( + Task task, GetJobActionRequest request, ActionListener listener) { + String responseContent; + if (request.getJobId() == null) { + responseContent = handleGetAllJobs(); + } else { + responseContent = handleGetJob(request.getJobId()); + } + listener.onResponse(new GetJobActionResponse(responseContent)); + } + + private String handleGetAllJobs() { + return "All Jobs Information."; + } + + private String handleGetJob(String jobId) { + return String.format("Job %s details.", jobId); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java new file mode 100644 index 0000000000..6aba1b48b6 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java @@ -0,0 +1,42 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportGetQueryResultRequestAction + extends HandledTransportAction< + GetJobQueryResultActionRequest, GetJobQueryResultActionResponse> { + + public static final String NAME = "cluster:admin/opensearch/ql/jobs/result"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, GetJobQueryResultActionResponse::new); + + @Inject + public TransportGetQueryResultRequestAction( + TransportService transportService, ActionFilters actionFilters) { + super(NAME, transportService, actionFilters, GetJobQueryResultActionRequest::new); + } + + @Override + protected void doExecute( + Task task, + GetJobQueryResultActionRequest request, + ActionListener listener) { + String responseContent = "job result"; + listener.onResponse(new GetJobQueryResultActionResponse(responseContent)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java new file mode 100644 index 0000000000..cbdcb617af --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java @@ -0,0 +1,34 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import 
org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; + +public class CreateJobActionRequest extends ActionRequest { + + @Getter private CreateJobRequest createJobRequest; + + /** Constructor of CreateJobActionRequest from StreamInput. */ + public CreateJobActionRequest(StreamInput in) throws IOException { + super(in); + } + + public CreateJobActionRequest(CreateJobRequest createJobRequest) { + this.createJobRequest = createJobRequest; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java new file mode 100644 index 0000000000..ce76d4a20d --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java @@ -0,0 +1,31 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +@RequiredArgsConstructor +public class CreateJobActionResponse extends ActionResponse { + + @Getter private final String result; + + public CreateJobActionResponse(StreamInput in) throws IOException { + super(in); + result = in.readString(); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeString(result); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java new file mode 100644 index 0000000000..eaf379047a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java @@ -0,0 +1,30 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.AllArgsConstructor; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; + +@AllArgsConstructor +public class DeleteJobActionRequest extends ActionRequest { + + private String jobId; + + /** Constructor of SubmitJobActionRequest from StreamInput. 
*/ + public DeleteJobActionRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java new file mode 100644 index 0000000000..38be57c21d --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java @@ -0,0 +1,31 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +@RequiredArgsConstructor +public class DeleteJobActionResponse extends ActionResponse { + + @Getter private final String result; + + public DeleteJobActionResponse(StreamInput in) throws IOException { + super(in); + result = in.readString(); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeString(result); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java new file mode 100644 index 0000000000..f8969cde15 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java @@ -0,0 +1,33 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; + +@NoArgsConstructor +@AllArgsConstructor +public class GetJobActionRequest extends ActionRequest { + + @Getter private String jobId; + + /** Constructor of GetJobActionRequest from StreamInput. 
*/ + public GetJobActionRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java new file mode 100644 index 0000000000..f904afdb4e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java @@ -0,0 +1,31 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +@RequiredArgsConstructor +public class GetJobActionResponse extends ActionResponse { + + @Getter private final String result; + + public GetJobActionResponse(StreamInput in) throws IOException { + super(in); + result = in.readString(); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeString(result); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java new file mode 100644 index 0000000000..1de7bae2c7 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java @@ -0,0 +1,31 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; + +@AllArgsConstructor +public class GetJobQueryResultActionRequest extends ActionRequest { + + @Getter private String jobId; + + /** Constructor of GetJobQueryResultActionRequest from StreamInput. 
*/ + public GetJobQueryResultActionRequest(StreamInput in) throws IOException { + super(in); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java new file mode 100644 index 0000000000..a7a8002c67 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java @@ -0,0 +1,31 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport.model; + +import java.io.IOException; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +@RequiredArgsConstructor +public class GetJobQueryResultActionResponse extends ActionResponse { + + @Getter private final String result; + + public GetJobQueryResultActionResponse(StreamInput in) throws IOException { + super(in); + result = in.readString(); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeString(result); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java new file mode 100644 index 0000000000..4357899368 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java @@ -0,0 +1,55 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import java.util.HashSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; +import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportCreateJobRequestActionTest { + + @Mock private TransportService transportService; + @Mock private TransportCreateJobRequestAction action; + @Mock private Task task; + @Mock private ActionListener actionListener; + + @Captor private ArgumentCaptor createJobActionResponseArgumentCaptor; + + @BeforeEach + public void setUp() { + action = + new TransportCreateJobRequestAction(transportService, new ActionFilters(new HashSet<>())); + } + + @Test + public void testDoExecute() { + CreateJobRequest createJobRequest = new CreateJobRequest("source = my_glue.default.alb_logs"); + CreateJobActionRequest request = new CreateJobActionRequest(createJobRequest); + + action.doExecute(task, request, actionListener); + 
Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + CreateJobActionResponse createJobActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals("submitted_job", createJobActionResponse.getResult()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java new file mode 100644 index 0000000000..828b264343 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java @@ -0,0 +1,53 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import java.util.HashSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.DeleteJobActionRequest; +import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportDeleteJobRequestActionTest { + + @Mock private TransportService transportService; + @Mock private TransportDeleteJobRequestAction action; + @Mock private Task task; + @Mock private ActionListener actionListener; + + @Captor private ArgumentCaptor deleteJobActionResponseArgumentCaptor; + + @BeforeEach + public void setUp() { + action = + new TransportDeleteJobRequestAction(transportService, new ActionFilters(new HashSet<>())); + } + + @Test + public void testDoExecute() { + DeleteJobActionRequest request = new DeleteJobActionRequest("jobId"); + + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); + DeleteJobActionResponse deleteJobActionResponse = + deleteJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals("deleted_job", deleteJobActionResponse.getResult()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java new file mode 100644 index 0000000000..06d1ee8baf --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java @@ -0,0 +1,60 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import java.util.HashSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.GetJobActionRequest; 
+import org.opensearch.sql.spark.transport.model.GetJobActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportGetJobRequestActionTest { + + @Mock private TransportService transportService; + @Mock private TransportGetJobRequestAction action; + @Mock private Task task; + @Mock private ActionListener actionListener; + + @Captor private ArgumentCaptor getJobActionResponseArgumentCaptor; + + @BeforeEach + public void setUp() { + action = new TransportGetJobRequestAction(transportService, new ActionFilters(new HashSet<>())); + } + + @Test + public void testDoExecuteWithSingleJob() { + GetJobActionRequest request = new GetJobActionRequest("abcd"); + + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(getJobActionResponseArgumentCaptor.capture()); + GetJobActionResponse getJobActionResponse = getJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals("Job abcd details.", getJobActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithAllJobs() { + GetJobActionRequest request = new GetJobActionRequest(); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onResponse(getJobActionResponseArgumentCaptor.capture()); + GetJobActionResponse getJobActionResponse = getJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals("All Jobs Information.", getJobActionResponse.getResult()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java new file mode 100644 index 0000000000..f22adead49 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java @@ -0,0 +1,54 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import java.util.HashSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; +import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportGetQueryResultRequestActionTest { + + @Mock private TransportService transportService; + @Mock private TransportGetQueryResultRequestAction action; + @Mock private Task task; + @Mock private ActionListener actionListener; + + @Captor + private ArgumentCaptor createJobActionResponseArgumentCaptor; + + @BeforeEach + public void setUp() { + action = + new TransportGetQueryResultRequestAction( + transportService, new ActionFilters(new HashSet<>())); + } + + @Test + public void testDoExecuteForSingleJob() { + GetJobQueryResultActionRequest request = new GetJobQueryResultActionRequest("jobId"); + action.doExecute(task, request, actionListener); + 
Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + GetJobQueryResultActionResponse getJobQueryResultActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals("job result", getJobQueryResultActionResponse.getResult()); + } +} diff --git a/spark/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/spark/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 0000000000..ca6ee9cea8 --- /dev/null +++ b/spark/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline \ No newline at end of file From 71f4155c2a432a408d568e0c015374988479a4ae Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Thu, 21 Sep 2023 09:13:37 -0700 Subject: [PATCH 3/5] Create Job API (#2070) * Create Job API Signed-off-by: Vamsi Manohar * Refactor to Async Query API Signed-off-by: Vamsi Manohar --------- Signed-off-by: Vamsi Manohar --- common/build.gradle | 4 +- .../sql/common/setting/Settings.java | 2 +- .../sql/datasource/DataSourceService.java | 9 + .../sql/analysis/AnalyzerTestBase.java | 5 + .../service/DataSourceServiceImpl.java | 41 ++- .../service/DataSourceServiceImplTest.java | 30 ++- docs/user/interfaces/asyncqueryinterface.rst | 108 ++++++++ integ-test/build.gradle | 1 + .../sql/datasource/DataSourceAPIsIT.java | 4 +- .../setting/OpenSearchSettings.java | 13 + .../org/opensearch/sql/plugin/SQLPlugin.java | 93 +++++-- .../plugin-metadata/plugin-security.policy | 9 + spark/build.gradle | 9 +- .../asyncquery/AsyncQueryExecutorService.java | 32 +++ .../AsyncQueryExecutorServiceImpl.java | 107 ++++++++ .../AsyncQueryJobMetadataStorageService.java | 18 ++ ...chAsyncQueryJobMetadataStorageService.java | 171 ++++++++++++ .../AsyncQueryNotFoundException.java | 15 ++ .../model/AsyncQueryExecutionResponse.java | 21 ++ .../model/AsyncQueryJobMetadata.java | 100 +++++++ .../asyncquery/model/AsyncQueryResult.java | 29 +++ .../model/S3GlueSparkSubmitParameters.java | 97 +++++++ .../sql/spark/client/EmrClientImpl.java | 4 +- .../spark/client/EmrServerlessClientImpl.java | 68 +++++ .../sql/spark/client/SparkJobClient.java | 22 ++ .../config/SparkExecutionEngineConfig.java | 22 ++ .../spark/data/constants/SparkConstants.java | 55 +++- .../dispatcher/SparkQueryDispatcher.java | 103 ++++++++ ...DefaultSparkSqlFunctionResponseHandle.java | 4 +- .../response/JobExecutionResponseReader.java | 67 +++++ .../sql/spark/response/SparkResponse.java | 8 +- ...va => RestAsyncQueryManagementAction.java} | 143 ++++------ ...uest.java => CreateAsyncQueryRequest.java} | 11 +- .../rest/model/CreateAsyncQueryResponse.java | 15 ++ ...ransportCancelAsyncQueryRequestAction.java | 41 +++ ...ransportCreateAsyncQueryRequestAction.java | 64 +++++ .../TransportCreateJobRequestAction.java | 39 --- .../TransportDeleteJobRequestAction.java | 39 --- .../TransportGetAsyncQueryResultAction.java | 70 +++++ .../TransportGetJobRequestAction.java | 52 ---- .../TransportGetQueryResultRequestAction.java | 42 --- .../AsyncQueryResultResponseFormatter.java | 90 +++++++ ...ava => CancelAsyncQueryActionRequest.java} | 6 +- ...va => CancelAsyncQueryActionResponse.java} | 4 +- ...ava => CreateAsyncQueryActionRequest.java} | 12 +- ...va => CreateAsyncQueryActionResponse.java} | 4 +- .../model/DeleteJobActionResponse.java | 31 --- ... 
=> GetAsyncQueryResultActionRequest.java} | 6 +- ...=> GetAsyncQueryResultActionResponse.java} | 4 +- .../transport/model/GetJobActionRequest.java | 33 --- .../resources/job-metadata-index-mapping.yml | 20 ++ .../resources/job-metadata-index-settings.yml | 11 + .../AsyncQueryExecutorServiceImplTest.java | 145 +++++++++++ ...yncQueryJobMetadataStorageServiceTest.java | 246 ++++++++++++++++++ .../client/EmrServerlessClientImplTest.java | 48 ++++ .../sql/spark/constants/TestConstants.java | 7 + .../dispatcher/SparkQueryDispatcherTest.java | 174 +++++++++++++ ...AsyncQueryExecutionResponseReaderTest.java | 78 ++++++ .../sql/spark/response/SparkResponseTest.java | 4 +- ...ortCancelAsyncQueryRequestActionTest.java} | 22 +- ...portCreateAsyncQueryRequestActionTest.java | 88 +++++++ .../TransportCreateJobRequestActionTest.java | 55 ---- ...ransportGetAsyncQueryResultActionTest.java | 139 ++++++++++ .../TransportGetJobRequestActionTest.java | 60 ----- ...nsportGetQueryResultRequestActionTest.java | 54 ---- ...AsyncQueryResultResponseFormatterTest.java | 40 +++ 66 files changed, 2574 insertions(+), 594 deletions(-) create mode 100644 docs/user/interfaces/asyncqueryinterface.rst create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/exceptions/AsyncQueryNotFoundException.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/S3GlueSparkSubmitParameters.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java rename spark/src/main/java/org/opensearch/sql/spark/rest/{RestJobManagementAction.java => RestAsyncQueryManagementAction.java} (50%) rename spark/src/main/java/org/opensearch/sql/spark/rest/model/{CreateJobRequest.java => CreateAsyncQueryRequest.java} (71%) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java create mode 100644 
spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java rename spark/src/main/java/org/opensearch/sql/spark/transport/model/{DeleteJobActionRequest.java => CancelAsyncQueryActionRequest.java} (77%) rename spark/src/main/java/org/opensearch/sql/spark/transport/model/{CreateJobActionResponse.java => CancelAsyncQueryActionResponse.java} (81%) rename spark/src/main/java/org/opensearch/sql/spark/transport/model/{CreateJobActionRequest.java => CreateAsyncQueryActionRequest.java} (55%) rename spark/src/main/java/org/opensearch/sql/spark/transport/model/{GetJobActionResponse.java => CreateAsyncQueryActionResponse.java} (81%) delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java rename spark/src/main/java/org/opensearch/sql/spark/transport/model/{GetJobQueryResultActionRequest.java => GetAsyncQueryResultActionRequest.java} (76%) rename spark/src/main/java/org/opensearch/sql/spark/transport/model/{GetJobQueryResultActionResponse.java => GetAsyncQueryResultActionResponse.java} (80%) delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java create mode 100644 spark/src/main/resources/job-metadata-index-mapping.yml create mode 100644 spark/src/main/resources/job-metadata-index-settings.yml create mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java rename spark/src/test/java/org/opensearch/sql/spark/transport/{TransportDeleteJobRequestActionTest.java => TransportCancelAsyncQueryRequestActionTest.java} (57%) create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java diff --git a/common/build.gradle b/common/build.gradle index 5cf219fbae..109cad59cb 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -39,8 +39,8 @@ dependencies { api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' implementation 'com.github.babbel:okhttp-aws-signer:1.0.2' - api group: 'com.amazonaws', name: 
'aws-java-sdk-core', version: '1.12.1' - api group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.1' + api group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.545' + api group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.545' implementation "com.github.seancfoley:ipaddress:5.4.0" testImplementation group: 'junit', name: 'junit', version: '4.13.2' diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index be780e8d80..8daf0e9bf6 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -35,7 +35,7 @@ public enum Key { METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"), METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"), - + SPARK_EXECUTION_ENGINE_CONFIG("plugins.query.executionengine.spark.config"), CLUSTER_NAME("cluster.name"); @Getter private final String keyValue; diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 3d6ddc864e..6dace50f99 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -39,6 +39,15 @@ public interface DataSourceService { */ DataSourceMetadata getDataSourceMetadata(String name); + /** + * Returns dataSourceMetadata object with specific name. The returned objects contain all the + * metadata information without any filtering. + * + * @param name name of the {@link DataSource}. + * @return set of {@link DataSourceMetadata}. + */ + DataSourceMetadata getRawDataSourceMetadata(String name); + /** * Register {@link DataSource} defined by {@link DataSourceMetadata}. 
* diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index f09bc5d380..a16d57673e 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -208,6 +208,11 @@ public DataSourceMetadata getDataSourceMetadata(String name) { return null; } + @Override + public DataSourceMetadata getRawDataSourceMetadata(String name) { + return null; + } + @Override public void createDataSource(DataSourceMetadata metadata) { throw new UnsupportedOperationException("unsupported operation"); diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 2ac480bbf2..d6c1907f84 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -64,29 +64,17 @@ public Set getDataSourceMetadata(boolean isDefaultDataSource } @Override - public DataSourceMetadata getDataSourceMetadata(String datasourceName) { - Optional dataSourceMetadataOptional = - getDataSourceMetadataFromName(datasourceName); - if (dataSourceMetadataOptional.isEmpty()) { - throw new IllegalArgumentException( - "DataSource with name: " + datasourceName + " doesn't exist."); - } - removeAuthInfo(dataSourceMetadataOptional.get()); - return dataSourceMetadataOptional.get(); + public DataSourceMetadata getDataSourceMetadata(String dataSourceName) { + DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); + removeAuthInfo(dataSourceMetadata); + return dataSourceMetadata; } @Override public DataSource getDataSource(String dataSourceName) { - Optional dataSourceMetadataOptional = - getDataSourceMetadataFromName(dataSourceName); - if (dataSourceMetadataOptional.isEmpty()) { - throw new DataSourceNotFoundException( - String.format("DataSource with name %s doesn't exist.", dataSourceName)); - } else { - DataSourceMetadata dataSourceMetadata = dataSourceMetadataOptional.get(); - this.dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - return dataSourceLoaderCache.getOrLoadDataSource(dataSourceMetadata); - } + DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); + this.dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); + return dataSourceLoaderCache.getOrLoadDataSource(dataSourceMetadata); } @Override @@ -146,11 +134,20 @@ private void validateDataSourceMetaData(DataSourceMetadata metadata) { + " Properties are required parameters."); } - private Optional getDataSourceMetadataFromName(String dataSourceName) { + @Override + public DataSourceMetadata getRawDataSourceMetadata(String dataSourceName) { if (dataSourceName.equals(DEFAULT_DATASOURCE_NAME)) { - return Optional.of(DataSourceMetadata.defaultOpenSearchDataSourceMetadata()); + return DataSourceMetadata.defaultOpenSearchDataSourceMetadata(); + } else { - return this.dataSourceMetadataStorage.getDataSourceMetadata(dataSourceName); + Optional dataSourceMetadataOptional = + this.dataSourceMetadataStorage.getDataSourceMetadata(dataSourceName); + if (dataSourceMetadataOptional.isEmpty()) { + throw new DataSourceNotFoundException( + String.format("DataSource with name %s doesn't exist.", dataSourceName)); + } else { + return 
dataSourceMetadataOptional.get(); + } } } diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index 56d3586c6e..eb28495541 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -359,11 +359,11 @@ void testRemovalOfAuthorizationInfo() { @Test void testGetDataSourceMetadataForNonExistingDataSource() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")).thenReturn(Optional.empty()); - IllegalArgumentException exception = + DataSourceNotFoundException exception = assertThrows( - IllegalArgumentException.class, + DataSourceNotFoundException.class, () -> dataSourceService.getDataSourceMetadata("testDS")); - assertEquals("DataSource with name: testDS doesn't exist.", exception.getMessage()); + assertEquals("DataSource with name testDS doesn't exist.", exception.getMessage()); } @Test @@ -385,4 +385,28 @@ void testGetDataSourceMetadataForSpecificDataSourceName() { assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.password")); verify(dataSourceMetadataStorage, times(1)).getDataSourceMetadata("testDS"); } + + @Test + void testGetRawDataSourceMetadata() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("prometheus.uri", "https://localhost:9090"); + properties.put("prometheus.auth.type", "basicauth"); + properties.put("prometheus.auth.username", "username"); + properties.put("prometheus.auth.password", "password"); + DataSourceMetadata dataSourceMetadata = + new DataSourceMetadata( + "testDS", + DataSourceType.PROMETHEUS, + Collections.singletonList("prometheus_access"), + properties); + when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) + .thenReturn(Optional.of(dataSourceMetadata)); + + DataSourceMetadata dataSourceMetadata1 = dataSourceService.getRawDataSourceMetadata("testDS"); + assertEquals("testDS", dataSourceMetadata1.getName()); + assertEquals(DataSourceType.PROMETHEUS, dataSourceMetadata1.getConnector()); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.password")); + } } diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst new file mode 100644 index 0000000000..98990b795b --- /dev/null +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -0,0 +1,108 @@ +.. highlight:: sh + +=============================== +Async Query Interface Endpoints +=============================== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 1 + + +Introduction +============ + +To support the `S3Glue <../ppl/admin/connectors/s3glue_connector.rst>`_ and CloudWatch datasource connectors, we have introduced a new execution engine on top of Spark. +All queries to be executed on the Spark execution engine can only be submitted via the Async Query APIs. The sections below list the new APIs introduced. + + +Configuration required for Async Query APIs +=========================================== +Currently, we only support AWS EMR Serverless as the Spark execution engine. The details of the execution engine should be configured under +``plugins.query.executionengine.spark.config`` cluster setting.
The value should be a stringified JSON comprising ``applicationId``, ``executionRoleARN``, and ``region``. +Sample Setting Value :: + + plugins.query.executionengine.spark.config: '{"applicationId":"xxxxx", "executionRoleARN":"arn:aws:iam::***********:role/emr-job-execution-role","region":"eu-west-1"}' + + +If this setting is not configured during bootstrap, the Async Query APIs are disabled, and a cluster restart is required to enable them again. +We use the default AWS credentials chain to make calls to the EMR Serverless application; also make sure the default credentials +have pass-role permissions for the emr-job-execution-role mentioned in the engine configuration. + + + +Async Query Creation API +====================================== +If the security plugin is enabled, this API can only be invoked by users with the permission ``cluster:admin/opensearch/ql/async_query/create``. + +HTTP URI: _plugins/_async_query +HTTP VERB: POST + + + +Sample Request:: + + curl --location 'http://localhost:9200/_plugins/_async_query' \ + --header 'Content-Type: application/json' \ + --data '{ + "kind" : "sql", + "query" : "select * from my_glue.default.http_logs limit 10" + }' + +Sample Response:: + + { + "queryId": "00fd796ut1a7eg0q" + } + +Async Query Result API +====================================== +If the security plugin is enabled, this API can only be invoked by users with the permission ``cluster:admin/opensearch/ql/async_query/result``. +Async query creation and result permissions are orthogonal, so any user with the result API permission and a queryId can fetch the corresponding query results, irrespective of the user who created the async query. + + +HTTP URI: _plugins/_async_query/{queryId} +HTTP VERB: GET + + +Sample Request BODY:: + + curl --location --request GET 'http://localhost:9200/_plugins/_async_query/00fd796ut1a7eg0q' \ + --header 'Content-Type: application/json' \ + --data '{ + "query" : "select * from default.http_logs limit 1" + }' + +Sample Response if the query is in progress :: + + {"status":"RUNNING"} + +Sample Response if the query is successful :: + + { + "schema": [ + { + "name": "indexed_col_name", + "type": "string" + }, + { + "name": "data_type", + "type": "string" + }, + { + "name": "skip_type", + "type": "string" + } + ], + "datarows": [ + [ + "status", + "int", + "VALUE_SET" + ] + ], + "total": 1, + "size": 1 + } diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 0404900450..dc92f9ebb3 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -162,6 +162,7 @@ configurations.all { resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk7:1.5.31" resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36" + resolutionStrategy.force "com.amazonaws:aws-java-sdk-core:1.12.545" } configurations { diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 6a6b4e7ba3..0b69a459a1 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -157,12 +157,12 @@ public void deleteDataSourceTest() { Assert.assertThrows( ResponseException.class, () -> client().performRequest(prometheusGetRequest)); Assert.assertEquals( - 400, prometheusGetResponseException.getResponse().getStatusLine().getStatusCode()); + 404,
prometheusGetResponseException.getResponse().getStatusLine().getStatusCode()); String prometheusGetResponseString = getResponseBody(prometheusGetResponseException.getResponse()); JsonObject errorMessage = new Gson().fromJson(prometheusGetResponseString, JsonObject.class); Assert.assertEquals( - "DataSource with name: delete_prometheus doesn't exist.", + "DataSource with name delete_prometheus doesn't exist.", errorMessage.get("error").getAsJsonObject().get("details").getAsString()); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 48ceacaf10..76bda07607 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -129,6 +129,12 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting SPARK_EXECUTION_ENGINE_CONFIG = + Setting.simpleString( + Key.SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue(), + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -193,6 +199,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.DATASOURCES_URI_HOSTS_DENY_LIST, DATASOURCE_URI_HOSTS_DENY_LIST, new Updater(Key.DATASOURCES_URI_HOSTS_DENY_LIST)); + register( + settingBuilder, + clusterSettings, + Key.SPARK_EXECUTION_ENGINE_CONFIG, + SPARK_EXECUTION_ENGINE_CONFIG, + new Updater(Key.SPARK_EXECUTION_ENGINE_CONFIG)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -257,6 +269,7 @@ public static List> pluginSettings() { .add(METRICS_ROLLING_WINDOW_SETTING) .add(METRICS_ROLLING_INTERVAL_SETTING) .add(DATASOURCE_URI_HOSTS_DENY_LIST) + .add(SPARK_EXECUTION_ENGINE_CONFIG) .build(); } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 80e1a6b1a3..ed10b1e3e6 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -5,10 +5,16 @@ package org.opensearch.sql.plugin; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -83,16 +89,23 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; -import org.opensearch.sql.spark.rest.RestJobManagementAction; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import 
org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.client.EmrServerlessClientImpl; +import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; -import org.opensearch.sql.spark.transport.TransportCreateJobRequestAction; -import org.opensearch.sql.spark.transport.TransportDeleteJobRequestAction; -import org.opensearch.sql.spark.transport.TransportGetJobRequestAction; -import org.opensearch.sql.spark.transport.TransportGetQueryResultRequestAction; -import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; -import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; -import org.opensearch.sql.spark.transport.model.GetJobActionResponse; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; +import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; +import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; +import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; @@ -110,6 +123,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { private NodeClient client; private DataSourceServiceImpl dataSourceService; + private AsyncQueryExecutorService asyncQueryExecutorService; private Injector injector; public String name() { @@ -142,7 +156,7 @@ public List getRestHandlers( new RestPPLStatsAction(settings, restController), new RestQuerySettingsAction(settings, restController), new RestDataSourceQueryAction(), - new RestJobManagementAction()); + new RestAsyncQueryManagementAction()); } /** Register action and handler so that transportClient can find proxy for action. 
*/ @@ -168,18 +182,17 @@ public List getRestHandlers( TransportDeleteDataSourceAction.NAME, DeleteDataSourceActionResponse::new), TransportDeleteDataSourceAction.class), new ActionHandler<>( - new ActionType<>(TransportCreateJobRequestAction.NAME, CreateJobActionResponse::new), - TransportCreateJobRequestAction.class), - new ActionHandler<>( - new ActionType<>(TransportGetJobRequestAction.NAME, GetJobActionResponse::new), - TransportGetJobRequestAction.class), + new ActionType<>( + TransportCreateAsyncQueryRequestAction.NAME, CreateAsyncQueryActionResponse::new), + TransportCreateAsyncQueryRequestAction.class), new ActionHandler<>( new ActionType<>( - TransportGetQueryResultRequestAction.NAME, GetJobQueryResultActionResponse::new), - TransportGetQueryResultRequestAction.class), + TransportGetAsyncQueryResultAction.NAME, GetAsyncQueryResultActionResponse::new), + TransportGetAsyncQueryResultAction.class), new ActionHandler<>( - new ActionType<>(TransportDeleteJobRequestAction.NAME, DeleteJobActionResponse::new), - TransportDeleteJobRequestAction.class)); + new ActionType<>( + TransportCancelAsyncQueryRequestAction.NAME, CancelAsyncQueryActionResponse::new), + TransportCancelAsyncQueryRequestAction.class)); } @Override @@ -202,6 +215,16 @@ public Collection createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); + if (StringUtils.isEmpty(this.pluginSettings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG))) { + LOGGER.warn( + String.format( + "Async Query APIs are disabled as %s is not configured in cluster settings. " + + "Please configure and restart the domain to enable Async Query APIs", + SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); + this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); + } else { + this.asyncQueryExecutorService = createAsyncQueryExecutorService(); + } ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); @@ -213,7 +236,7 @@ public Collection createComponents( }); injector = modules.createInjector(); - return ImmutableList.of(dataSourceService); + return ImmutableList.of(dataSourceService, asyncQueryExecutorService); } @Override @@ -270,4 +293,34 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); } + + private AsyncQueryExecutorService createAsyncQueryExecutorService() { + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = + new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); + SparkJobClient sparkJobClient = createEMRServerlessClient(); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, this.dataSourceService, jobExecutionResponseReader); + return new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, pluginSettings); + } + + private SparkJobClient createEMRServerlessClient() { + String sparkExecutionEngineConfigString = + this.pluginSettings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + SparkExecutionEngineConfig sparkExecutionEngineConfig = + SparkExecutionEngineConfig.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigString); + AWSEMRServerless awsemrServerless 
= + AWSEMRServerlessClientBuilder.standard() + .withRegion(sparkExecutionEngineConfig.getRegion()) + .withCredentials(new DefaultAWSCredentialsProviderChain()) + .build(); + return new EmrServerlessClientImpl(awsemrServerless); + }); + } } diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index aec517aa84..fcf70c01f9 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -15,4 +15,13 @@ grant { // ml-commons client permission java.lang.RuntimePermission "setContextClassLoader"; + + // aws credentials + permission java.io.FilePermission "${user.home}${/}.aws${/}*", "read"; + + // Permissions for aws emr servless sdk + permission javax.management.MBeanServerPermission "createMBeanServer"; + permission javax.management.MBeanServerPermission "findMBeanServer"; + permission javax.management.MBeanPermission "com.amazonaws.metrics.*", "*"; + permission javax.management.MBeanTrustPermission "register"; }; diff --git a/spark/build.gradle b/spark/build.gradle index b93e3327ce..fb9a1e0e4b 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -15,11 +15,14 @@ repositories { dependencies { api project(':core') + implementation project(':protocol') implementation project(':datasources') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.json', name: 'json', version: '20230227' - implementation group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.1' + api group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.545' + api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' + implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' @@ -56,7 +59,9 @@ jacocoTestCoverageVerification { excludes = [ 'org.opensearch.sql.spark.data.constants.*', 'org.opensearch.sql.spark.rest.*', - 'org.opensearch.sql.spark.transport.model.*' + 'org.opensearch.sql.spark.transport.model.*', + 'org.opensearch.sql.spark.asyncquery.model.*', + 'org.opensearch.sql.spark.asyncquery.exceptions.*' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java new file mode 100644 index 0000000000..df13daa2a2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery; + +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; + +/** + * AsyncQueryExecutorService exposes functionality to create, get results and cancel an async query. + */ +public interface AsyncQueryExecutorService { + + /** + * Creates async query job based on the request and returns queryId in the response. + * + * @param createAsyncQueryRequest createAsyncQueryRequest. 
+ * @return {@link CreateAsyncQueryResponse} + */ + CreateAsyncQueryResponse createAsyncQuery(CreateAsyncQueryRequest createAsyncQueryRequest); + + /** + * Returns async query response for a given queryId. + * + * @param queryId queryId. + * @return {@link AsyncQueryExecutionResponse} + */ + AsyncQueryExecutionResponse getAsyncQueryResults(String queryId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java new file mode 100644 index 0000000000..e5ed65920e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery; + +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; + +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import lombok.AllArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; + +/** AsyncQueryExecutorService implementation of {@link AsyncQueryExecutorService}. 
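+ * When the Spark execution engine setting is absent the service stays disabled; otherwise it
+ * dispatches queries via SparkQueryDispatcher and persists job metadata for later result lookups.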
*/ +@AllArgsConstructor +public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService { + private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; + private SparkQueryDispatcher sparkQueryDispatcher; + private Settings settings; + private Boolean isSparkJobExecutionEnabled; + + public AsyncQueryExecutorServiceImpl() { + this.isSparkJobExecutionEnabled = Boolean.FALSE; + } + + public AsyncQueryExecutorServiceImpl( + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService, + SparkQueryDispatcher sparkQueryDispatcher, + Settings settings) { + this.isSparkJobExecutionEnabled = Boolean.TRUE; + this.asyncQueryJobMetadataStorageService = asyncQueryJobMetadataStorageService; + this.sparkQueryDispatcher = sparkQueryDispatcher; + this.settings = settings; + } + + @Override + public CreateAsyncQueryResponse createAsyncQuery( + CreateAsyncQueryRequest createAsyncQueryRequest) { + validateSparkExecutionEngineSettings(); + String sparkExecutionEngineConfigString = + settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + AccessController.doPrivileged( + (PrivilegedAction) + () -> + SparkExecutionEngineConfig.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigString)); + String jobId = + sparkQueryDispatcher.dispatch( + sparkExecutionEngineConfig.getApplicationId(), + createAsyncQueryRequest.getQuery(), + sparkExecutionEngineConfig.getExecutionRoleARN()); + asyncQueryJobMetadataStorageService.storeJobMetadata( + new AsyncQueryJobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); + return new CreateAsyncQueryResponse(jobId); + } + + @Override + public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { + validateSparkExecutionEngineSettings(); + Optional jobMetadata = + asyncQueryJobMetadataStorageService.getJobMetadata(queryId); + if (jobMetadata.isPresent()) { + JSONObject jsonObject = + sparkQueryDispatcher.getQueryResponse( + jobMetadata.get().getApplicationId(), jobMetadata.get().getJobId()); + if (JobRunState.SUCCESS.toString().equals(jsonObject.getString("status"))) { + DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = + new DefaultSparkSqlFunctionResponseHandle(jsonObject); + List result = new ArrayList<>(); + while (sparkSqlFunctionResponseHandle.hasNext()) { + result.add(sparkSqlFunctionResponseHandle.next()); + } + return new AsyncQueryExecutionResponse( + JobRunState.SUCCESS.toString(), sparkSqlFunctionResponseHandle.schema(), result); + } else { + return new AsyncQueryExecutionResponse(jsonObject.getString("status"), null, null); + } + } + throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); + } + + private void validateSparkExecutionEngineSettings() { + if (!isSparkJobExecutionEnabled) { + throw new IllegalArgumentException( + String.format( + "Async Query APIs are disabled as %s is not configured in cluster settings. 
Please" + + " configure the setting and restart the domain to enable Async Query APIs", + SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java new file mode 100644 index 0000000000..4ce34458cd --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java @@ -0,0 +1,18 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.asyncquery; + +import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; + +public interface AsyncQueryJobMetadataStorageService { + + void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata); + + Optional getJobMetadata(String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java new file mode 100644 index 0000000000..cee38d10f8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -0,0 +1,171 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.asyncquery; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.apache.commons.io.IOUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; + +/** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */ +public class OpensearchAsyncQueryJobMetadataStorageService + implements AsyncQueryJobMetadataStorageService { + + public static final String JOB_METADATA_INDEX = ".ql-job-metadata"; + private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME = + "job-metadata-index-mapping.yml"; + private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME = + "job-metadata-index-settings.yml"; + private static final Logger LOG = LogManager.getLogger(); + private final Client client; + private final ClusterService clusterService; + + /** + * This class implements JobMetadataStorageService 
interface using OpenSearch as underlying + * storage. + * + * @param client opensearch NodeClient. + * @param clusterService ClusterService. + */ + public OpensearchAsyncQueryJobMetadataStorageService( + Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createJobMetadataIndex(); + } + IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX); + indexRequest.id(asyncQueryJobMetadata.getJobId()); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + ActionFuture indexResponseActionFuture; + IndexResponse indexResponse; + try (ThreadContext.StoredContext storedContext = + client.threadPool().getThreadContext().stashContext()) { + indexRequest.source(AsyncQueryJobMetadata.convertToXContent(asyncQueryJobMetadata)); + indexResponseActionFuture = client.index(indexRequest); + indexResponse = indexResponseActionFuture.actionGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("JobMetadata : {} successfully created", asyncQueryJobMetadata.getJobId()); + } else { + throw new RuntimeException( + "Saving job metadata information failed with result : " + + indexResponse.getResult().getLowercase()); + } + } + + @Override + public Optional getJobMetadata(String jobId) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createJobMetadataIndex(); + return Optional.empty(); + } + return searchInJobMetadataIndex(QueryBuilders.termQuery("jobId", jobId)).stream().findFirst(); + } + + private void createJobMetadataIndex() { + try { + InputStream mappingFileStream = + OpensearchAsyncQueryJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME); + InputStream settingsFileStream = + OpensearchAsyncQueryJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX); + createIndexRequest + .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML) + .settings( + IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); + ActionFuture createIndexResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); + } + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + if (createIndexResponse.isAcknowledged()) { + LOG.info("Index: {} creation Acknowledged", JOB_METADATA_INDEX); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + throw new RuntimeException( + "Internal server error while creating" + + JOB_METADATA_INDEX + + " index:: " + + e.getMessage()); + } + } + + private List searchInJobMetadataIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(JOB_METADATA_INDEX); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchSourceBuilder.size(1); + 
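+    // Each jobId maps to at most one metadata document, so a single search hit is sufficient.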
searchRequest.source(searchSourceBuilder); + // https://github.com/opensearch-project/sql/issues/1801. + searchRequest.preference("_primary_first"); + ActionFuture searchResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + searchResponseActionFuture = client.search(searchRequest); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching job metadata information failed with status : " + searchResponse.status()); + } else { + List list = new ArrayList<>(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + String sourceAsString = searchHit.getSourceAsString(); + AsyncQueryJobMetadata asyncQueryJobMetadata; + try { + asyncQueryJobMetadata = AsyncQueryJobMetadata.toJobMetadata(sourceAsString); + } catch (IOException e) { + throw new RuntimeException(e); + } + list.add(asyncQueryJobMetadata); + } + return list; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/exceptions/AsyncQueryNotFoundException.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/exceptions/AsyncQueryNotFoundException.java new file mode 100644 index 0000000000..80a0c34b70 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/exceptions/AsyncQueryNotFoundException.java @@ -0,0 +1,15 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.asyncquery.exceptions; + +/** AsyncQueryNotFoundException. */ +public class AsyncQueryNotFoundException extends RuntimeException { + public AsyncQueryNotFoundException(String message) { + super(message); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java new file mode 100644 index 0000000000..84dcc490ba --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryExecutionResponse.java @@ -0,0 +1,21 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.asyncquery.model; + +import java.util.List; +import lombok.Data; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; + +/** AsyncQueryExecutionResponse to store the response form spark job execution. 
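+ * Carries the job status along with the schema and rows parsed from the Spark response.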
*/ +@Data +public class AsyncQueryExecutionResponse { + private final String status; + private final ExecutionEngine.Schema schema; + private final List results; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java new file mode 100644 index 0000000000..60ec53987e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -0,0 +1,100 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.asyncquery.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import com.google.gson.Gson; +import java.io.IOException; +import lombok.AllArgsConstructor; +import lombok.Data; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +/** This class models all the metadata required for a job. */ +@Data +@AllArgsConstructor +public class AsyncQueryJobMetadata { + private String jobId; + private String applicationId; + + @Override + public String toString() { + return new Gson().toJson(this); + } + + /** + * Converts JobMetadata to XContentBuilder. + * + * @param metadata metadata. + * @return XContentBuilder {@link XContentBuilder} + * @throws Exception Exception. + */ + public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("jobId", metadata.getJobId()); + builder.field("applicationId", metadata.getApplicationId()); + builder.endObject(); + return builder; + } + + /** + * Converts json string to DataSourceMetadata. + * + * @param json jsonstring. + * @return jobmetadata {@link AsyncQueryJobMetadata} + * @throws java.io.IOException IOException. + */ + public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOException { + try (XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + json)) { + return toJobMetadata(parser); + } + } + + /** + * Convert xcontent parser to JobMetadata. + * + * @param parser parser. + * @return JobMetadata {@link AsyncQueryJobMetadata} + * @throws IOException IOException. 
+ */ + public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException { + String jobId = null; + String applicationId = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case "jobId": + jobId = parser.textOrNull(); + break; + case "applicationId": + applicationId = parser.textOrNull(); + break; + default: + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + } + if (jobId == null || applicationId == null) { + throw new IllegalArgumentException("jobId and applicationId are required fields."); + } + return new AsyncQueryJobMetadata(jobId, applicationId); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java new file mode 100644 index 0000000000..6d6bce8fbc --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java @@ -0,0 +1,29 @@ +package org.opensearch.sql.spark.asyncquery.model; + +import java.util.Collection; +import lombok.Getter; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.protocol.response.QueryResult; + +/** AsyncQueryResult for async query APIs. */ +public class AsyncQueryResult extends QueryResult { + + @Getter private final String status; + + public AsyncQueryResult( + String status, + ExecutionEngine.Schema schema, + Collection exprValues, + Cursor cursor) { + super(schema, exprValues, cursor); + this.status = status; + } + + public AsyncQueryResult( + String status, ExecutionEngine.Schema schema, Collection exprValues) { + super(schema, exprValues); + this.status = status; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/S3GlueSparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/S3GlueSparkSubmitParameters.java new file mode 100644 index 0000000000..fadb8a67a9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/S3GlueSparkSubmitParameters.java @@ -0,0 +1,97 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.asyncquery.model; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.AWS_SNAPSHOT_REPOSITORY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CATALOG_JAR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; +import static 
org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_REGION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SQL_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_CATALOG_HIVE_JAR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_HIVE_CATALOG_FACTORY_CLASS; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_CLASS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.JAVA_HOME_LOCATION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.S3_AWS_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_JAVA_HOME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_JAVA_HOME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_PACKAGES_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_REPOSITORIES_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_EXTENSIONS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_STANDALONE_PACKAGE; + +import java.util.LinkedHashMap; +import java.util.Map; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class S3GlueSparkSubmitParameters { + + private String className; + private Map config; + public static final String SPACE = " "; + public static final String EQUALS = "="; + + public S3GlueSparkSubmitParameters() { + this.className = DEFAULT_CLASS_NAME; + this.config = new LinkedHashMap<>(); + this.config.put(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); + this.config.put( + HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, + DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + this.config.put(SPARK_JARS_KEY, GLUE_CATALOG_HIVE_JAR + "," + FLINT_CATALOG_JAR); + this.config.put(SPARK_JAR_PACKAGES_KEY, SPARK_STANDALONE_PACKAGE); + this.config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); + this.config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + this.config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + this.config.put(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); + this.config.put(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); + this.config.put(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); + this.config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); + this.config.put(FLINT_INDEX_STORE_AWSREGION_KEY, FLINT_DEFAULT_REGION); + this.config.put(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); + this.config.put(SPARK_SQL_EXTENSIONS_KEY, 
FLINT_SQL_EXTENSION); + this.config.put(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); + } + + public void addParameter(String key, String value) { + this.config.put(key, value); + } + + @Override + public String toString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append(" --class "); + stringBuilder.append(this.className); + stringBuilder.append(SPACE); + for (String key : config.keySet()) { + stringBuilder.append(" --conf "); + stringBuilder.append(key); + stringBuilder.append(EQUALS); + stringBuilder.append(config.get(key)); + stringBuilder.append(SPACE); + } + return stringBuilder.toString(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java index 1a3304994b..4e66cd9a00 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.client; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; @@ -74,7 +74,7 @@ void runEmrApplication(String query) { flint.getFlintIntegrationJar(), sparkApplicationJar, query, - SPARK_INDEX_NAME, + SPARK_RESPONSE_BUFFER_INDEX_NAME, flint.getFlintHost(), flint.getFlintPort(), flint.getFlintScheme(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java new file mode 100644 index 0000000000..b554c4cd23 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; + +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.GetJobRunRequest; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobDriver; +import com.amazonaws.services.emrserverless.model.SparkSubmit; +import com.amazonaws.services.emrserverless.model.StartJobRunRequest; +import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import java.security.AccessController; +import java.security.PrivilegedAction; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class EmrServerlessClientImpl implements SparkJobClient { + + private final AWSEMRServerless emrServerless; + private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class); + + public EmrServerlessClientImpl(AWSEMRServerless emrServerless) { + this.emrServerless = emrServerless; + } + + @Override + public String startJobRun( + String query, + String jobName, + String applicationId, + String executionRoleArn, + String sparkSubmitParams) { + StartJobRunRequest request = + new StartJobRunRequest() + .withName(jobName) + 
.withApplicationId(applicationId) + .withExecutionRoleArn(executionRoleArn) + .withJobDriver( + new JobDriver() + .withSparkSubmit( + new SparkSubmit() + .withEntryPoint(SPARK_SQL_APPLICATION_JAR) + .withEntryPointArguments(query, SPARK_RESPONSE_BUFFER_INDEX_NAME) + .withSparkSubmitParameters(sparkSubmitParams))); + StartJobRunResult startJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) () -> emrServerless.startJobRun(request)); + logger.info("Job Run ID: " + startJobRunResult.getJobRunId()); + return startJobRunResult.getJobRunId(); + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + GetJobRunRequest request = + new GetJobRunRequest().withApplicationId(applicationId).withJobRunId(jobId); + GetJobRunResult getJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) () -> emrServerless.getJobRun(request)); + logger.info("Job Run state: " + getJobRunResult.getJobRun().getState()); + return getJobRunResult; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java new file mode 100644 index 0000000000..ff9f4acedd --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java @@ -0,0 +1,22 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.client; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; + +public interface SparkJobClient { + + String startJobRun( + String query, + String jobName, + String applicationId, + String executionRoleArn, + String sparkSubmitParams); + + GetJobRunResult getJobRunResult(String applicationId, String jobId); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java new file mode 100644 index 0000000000..4f928c4f1f --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.google.gson.Gson; +import lombok.Data; + +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class SparkExecutionEngineConfig { + private String applicationId; + private String region; + private String executionRoleARN; + + public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) { + return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 65d5a01ba2..21db8b9478 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -8,13 +8,64 @@ public class SparkConstants { public static final String EMR = "emr"; public static final String STEP_ID_FIELD = "stepId.keyword"; - public static final String SPARK_SQL_APPLICATION_JAR = "s3://spark-datasource/sql-job.jar"; - public static final String SPARK_INDEX_NAME = ".query_execution_result"; + // TODO should be replaced with mvn jar. 
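// Editor's note (illustrative, not part of this patch): SPARK_SQL_APPLICATION_JAR below is the
// entry-point jar that EmrServerlessClientImpl.startJobRun hands to EMR Serverless via
// SparkSubmit.withEntryPoint; it and the Glue/Flint jars defined after it are hard-coded S3
// locations for now, to be replaced with Maven-published artifacts as the TODOs note.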
+ public static final String SPARK_SQL_APPLICATION_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/sql-job.jar"; + public static final String SPARK_RESPONSE_BUFFER_INDEX_NAME = ".query_execution_result"; + // TODO should be replaced with mvn jar. public static final String FLINT_INTEGRATION_JAR = "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; + // TODO should be replaced with mvn jar. + public static final String GLUE_CATALOG_HIVE_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar"; + // TODO should be replaced with mvn jar. + public static final String FLINT_CATALOG_JAR = + "s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar"; public static final String FLINT_DEFAULT_HOST = "localhost"; public static final String FLINT_DEFAULT_PORT = "9200"; public static final String FLINT_DEFAULT_SCHEME = "http"; public static final String FLINT_DEFAULT_AUTH = "-1"; public static final String FLINT_DEFAULT_REGION = "us-west-2"; + public static final String DEFAULT_CLASS_NAME = "org.opensearch.sql.FlintJob"; + public static final String S3_AWS_CREDENTIALS_PROVIDER_KEY = + "spark.hadoop.fs.s3.customAWSCredentialsProvider"; + public static final String DRIVER_ENV_ASSUME_ROLE_ARN_KEY = + "spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN"; + public static final String EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY = + "spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN"; + public static final String HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY = + "spark.hadoop.aws.catalog.credentials.provider.factory.class"; + public static final String HIVE_METASTORE_GLUE_ARN_KEY = "spark.hive.metastore.glue.role.arn"; + public static final String SPARK_JARS_KEY = "spark.jars"; + public static final String SPARK_JAR_PACKAGES_KEY = "spark.jars.packages"; + public static final String SPARK_JAR_REPOSITORIES_KEY = "spark.jars.repositories"; + public static final String SPARK_DRIVER_ENV_JAVA_HOME_KEY = + "spark.emr-serverless.driverEnv.JAVA_HOME"; + public static final String SPARK_EXECUTOR_ENV_JAVA_HOME_KEY = "spark.executorEnv.JAVA_HOME"; + public static final String FLINT_INDEX_STORE_HOST_KEY = "spark.datasource.flint.host"; + public static final String FLINT_INDEX_STORE_PORT_KEY = "spark.datasource.flint.port"; + public static final String FLINT_INDEX_STORE_SCHEME_KEY = "spark.datasource.flint.scheme"; + public static final String FLINT_INDEX_STORE_AUTH_KEY = "spark.datasource.flint.auth"; + public static final String FLINT_INDEX_STORE_AWSREGION_KEY = "spark.datasource.flint.region"; + public static final String FLINT_CREDENTIALS_PROVIDER_KEY = + "spark.datasource.flint.customAWSCredentialsProvider"; + public static final String SPARK_SQL_EXTENSIONS_KEY = "spark.sql.extensions"; + public static final String HIVE_METASTORE_CLASS_KEY = + "spark.hadoop.hive.metastore.client.factory.class"; + public static final String DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE = + "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; + public static final String DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY = + "com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory"; + public static final String SPARK_STANDALONE_PACKAGE = + "org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT"; + public static final String AWS_SNAPSHOT_REPOSITORY = + "https://aws.oss.sonatype.org/content/repositories/snapshots"; + public static final String GLUE_HIVE_CATALOG_FACTORY_CLASS = + 
"com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"; + public static final String FLINT_DELEGATE_CATALOG = "org.opensearch.sql.FlintDelegateCatalog"; + public static final String FLINT_SQL_EXTENSION = + "org.opensearch.flint.spark.FlintSparkExtensions"; + public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = + "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; + public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java new file mode 100644 index 0000000000..f632ceaf6a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.DRIVER_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DELEGATE_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.net.URI; +import java.net.URISyntaxException; +import lombok.AllArgsConstructor; +import org.json.JSONObject; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.S3GlueSparkSubmitParameters; +import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +/** This class takes care of understanding query and dispatching job query to emr serverless. */ +@AllArgsConstructor +public class SparkQueryDispatcher { + + private SparkJobClient sparkJobClient; + + private DataSourceService dataSourceService; + + private JobExecutionResponseReader jobExecutionResponseReader; + + public String dispatch(String applicationId, String query, String executionRoleARN) { + String datasourceName = getDataSourceName(); + try { + return sparkJobClient.startJobRun( + query, + "flint-opensearch-query", + applicationId, + executionRoleARN, + constructSparkParameters(datasourceName)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); + } + } + + // TODO : Fetch from Result Index and then make call to EMR Serverless. 
+ public JSONObject getQueryResponse(String applicationId, String queryId) { + GetJobRunResult getJobRunResult = sparkJobClient.getJobRunResult(applicationId, queryId); + JSONObject result = new JSONObject(); + if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { + result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId); + } + result.put("status", getJobRunResult.getJobRun().getState()); + return result; + } + + // TODO: Analyze given query + // Extract datasourceName + // Apply Authorization. + private String getDataSourceName() { + return "my_glue"; + } + + // TODO: Analyze given query and get the role arn based on datasource type. + private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) { + return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + } + + private String constructSparkParameters(String datasourceName) throws URISyntaxException { + DataSourceMetadata dataSourceMetadata = + dataSourceService.getRawDataSourceMetadata(datasourceName); + S3GlueSparkSubmitParameters s3GlueSparkSubmitParameters = new S3GlueSparkSubmitParameters(); + s3GlueSparkSubmitParameters.addParameter( + DRIVER_ENV_ASSUME_ROLE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); + s3GlueSparkSubmitParameters.addParameter( + EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); + s3GlueSparkSubmitParameters.addParameter( + HIVE_METASTORE_GLUE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); + String opensearchuri = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.uri"); + URI uri = new URI(opensearchuri); + String auth = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.auth"); + String region = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.region"); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); + s3GlueSparkSubmitParameters.addParameter( + FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_AUTH_KEY, auth); + s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_AWSREGION_KEY, region); + s3GlueSparkSubmitParameters.addParameter( + "spark.sql.catalog."
+ datasourceName, FLINT_DELEGATE_CATALOG); + return s3GlueSparkSubmitParameters.toString(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java index 823ad2da29..9d0cd59cf8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -15,7 +15,6 @@ import org.json.JSONObject; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprByteValue; -import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprFloatValue; import org.opensearch.sql.data.model.ExprIntegerValue; @@ -81,7 +80,8 @@ private static LinkedHashMap extractRow( } else if (type == ExprCoreType.FLOAT) { linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); } else if (type == ExprCoreType.DATE) { - linkedHashMap.put(column.getName(), new ExprDateValue(row.getString(column.getName()))); + // TODO :: correct this to ExprTimestampValue + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); } else if (type == ExprCoreType.TIMESTAMP) { linkedHashMap.put( column.getName(), new ExprTimestampValue(row.getString(column.getName()))); diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java new file mode 100644 index 0000000000..8abb7cd11f --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/response/JobExecutionResponseReader.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class JobExecutionResponseReader { + private final Client client; + private static final Logger LOG = LogManager.getLogger(); + + /** + * JobExecutionResponseReader for spark query. 
+ * + * @param client Opensearch client + */ + public JobExecutionResponseReader(Client client) { + this.client = client; + } + + public JSONObject getResultFromOpensearchIndex(String jobId) { + return searchInSparkIndex(QueryBuilders.termQuery(STEP_ID_FIELD, jobId)); + } + + private JSONObject searchInSparkIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + ActionFuture searchResponseActionFuture; + try { + searchResponseActionFuture = client.search(searchRequest); + } catch (Exception e) { + throw new RuntimeException(e); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching result from " + + SPARK_RESPONSE_BUFFER_INDEX_NAME + + " index failed with status : " + + searchResponse.status()); + } else { + JSONObject data = new JSONObject(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + data.put("data", searchHit.getSourceAsMap()); + } + return data; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java index 3edb541384..496caba2c9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.response; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; import com.google.common.annotations.VisibleForTesting; import lombok.Data; @@ -51,7 +51,7 @@ public JSONObject getResultFromOpensearchIndex() { private JSONObject searchInSparkIndex(QueryBuilder query) { SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(SPARK_INDEX_NAME); + searchRequest.indices(SPARK_RESPONSE_BUFFER_INDEX_NAME); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(query); searchRequest.source(searchSourceBuilder); @@ -65,7 +65,7 @@ private JSONObject searchInSparkIndex(QueryBuilder query) { if (searchResponse.status().getStatus() != 200) { throw new RuntimeException( "Fetching result from " - + SPARK_INDEX_NAME + + SPARK_RESPONSE_BUFFER_INDEX_NAME + " index failed with status : " + searchResponse.status()); } else { @@ -80,7 +80,7 @@ private JSONObject searchInSparkIndex(QueryBuilder query) { @VisibleForTesting void deleteInSparkIndex(String id) { - DeleteRequest deleteRequest = new DeleteRequest(SPARK_INDEX_NAME); + DeleteRequest deleteRequest = new DeleteRequest(SPARK_RESPONSE_BUFFER_INDEX_NAME); deleteRequest.id(id); ActionFuture deleteResponseActionFuture; try { diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java similarity index 50% rename from spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java rename to spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 669cbb6aca..56484688dc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestJobManagementAction.java +++ 
b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -27,30 +27,27 @@ import org.opensearch.rest.RestRequest; import org.opensearch.sql.datasources.exceptions.ErrorMessage; import org.opensearch.sql.datasources.utils.Scheduler; -import org.opensearch.sql.spark.rest.model.CreateJobRequest; -import org.opensearch.sql.spark.transport.TransportCreateJobRequestAction; -import org.opensearch.sql.spark.transport.TransportDeleteJobRequestAction; -import org.opensearch.sql.spark.transport.TransportGetJobRequestAction; -import org.opensearch.sql.spark.transport.TransportGetQueryResultRequestAction; -import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; -import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; -import org.opensearch.sql.spark.transport.model.DeleteJobActionRequest; -import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; -import org.opensearch.sql.spark.transport.model.GetJobActionRequest; -import org.opensearch.sql.spark.transport.model.GetJobActionResponse; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; +import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; +import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; -public class RestJobManagementAction extends BaseRestHandler { +public class RestAsyncQueryManagementAction extends BaseRestHandler { - public static final String JOB_ACTIONS = "job_actions"; - public static final String BASE_JOB_ACTION_URL = "/_plugins/_query/_jobs"; + public static final String ASYNC_QUERY_ACTIONS = "async_query_actions"; + public static final String BASE_ASYNC_QUERY_ACTION_URL = "/_plugins/_async_query"; - private static final Logger LOG = LogManager.getLogger(RestJobManagementAction.class); + private static final Logger LOG = LogManager.getLogger(RestAsyncQueryManagementAction.class); @Override public String getName() { - return JOB_ACTIONS; + return ASYNC_QUERY_ACTIONS; } @Override @@ -59,47 +56,38 @@ public List routes() { /* * - * Create a new job with spark execution engine. + * Create a new async query using spark execution engine. * Request URL: POST * Request body: - * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionRequest] + * Ref [org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest] * Response body: - * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionResponse] + * Ref [org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse] */ - new Route(POST, BASE_JOB_ACTION_URL), + new Route(POST, BASE_ASYNC_QUERY_ACTION_URL), /* * - * GET jobs with in spark execution engine. + * GET Async Query result with in spark execution engine. 
* Request URL: GET * Request body: - * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionRequest] + * Ref [org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest] * Response body: - * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionResponse] + * Ref [org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse] */ - new Route(GET, String.format(Locale.ROOT, "%s/{%s}", BASE_JOB_ACTION_URL, "jobId")), - new Route(GET, BASE_JOB_ACTION_URL), + new Route( + GET, String.format(Locale.ROOT, "%s/{%s}", BASE_ASYNC_QUERY_ACTION_URL, "queryId")), /* * * Cancel a job within spark execution engine. * Request URL: DELETE * Request body: - * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionRequest] + * Ref [org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest] * Response body: - * Ref [org.opensearch.sql.spark.transport.model.SubmitJobActionResponse] + * Ref [org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse] */ - new Route(DELETE, String.format(Locale.ROOT, "%s/{%s}", BASE_JOB_ACTION_URL, "jobId")), - - /* - * GET query result from job {{jobId}} execution. - * Request URL: GET - * Request body: - * Ref [org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest] - * Response body: - * Ref [org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse] - */ - new Route(GET, String.format(Locale.ROOT, "%s/{%s}/result", BASE_JOB_ACTION_URL, "jobId"))); + new Route( + DELETE, String.format(Locale.ROOT, "%s/{%s}", BASE_ASYNC_QUERY_ACTION_URL, "queryId"))); } @Override @@ -109,7 +97,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient case POST: return executePostRequest(restRequest, nodeClient); case GET: - return executeGetRequest(restRequest, nodeClient); + return executeGetAsyncQueryResultRequest(restRequest, nodeClient); case DELETE: return executeDeleteRequest(restRequest, nodeClient); default: @@ -122,23 +110,24 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) throws IOException { - CreateJobRequest submitJobRequest = - CreateJobRequest.fromXContentParser(restRequest.contentParser()); + CreateAsyncQueryRequest submitJobRequest = + CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( - TransportCreateJobRequestAction.ACTION_TYPE, - new CreateJobActionRequest(submitJobRequest), + TransportCreateAsyncQueryRequestAction.ACTION_TYPE, + new CreateAsyncQueryActionRequest(submitJobRequest), new ActionListener<>() { @Override - public void onResponse(CreateJobActionResponse createJobActionResponse) { + public void onResponse( + CreateAsyncQueryActionResponse createAsyncQueryActionResponse) { restChannel.sendResponse( new BytesRestResponse( RestStatus.CREATED, "application/json; charset=UTF-8", - submitJobRequest.getQuery())); + createAsyncQueryActionResponse.getResult())); } @Override @@ -148,60 +137,25 @@ public void onFailure(Exception e) { })); } - private RestChannelConsumer executeGetRequest(RestRequest restRequest, NodeClient nodeClient) { - Boolean isResultRequest = restRequest.rawPath().contains("result"); - if (isResultRequest) { - return executeGetJobQueryResultRequest(nodeClient, restRequest); - } else { - return executeGetJobRequest(nodeClient, restRequest); - } - } - - private RestChannelConsumer 
executeGetJobQueryResultRequest( - NodeClient nodeClient, RestRequest restRequest) { - String jobId = restRequest.param("jobId"); + private RestChannelConsumer executeGetAsyncQueryResultRequest( + RestRequest restRequest, NodeClient nodeClient) { + String queryId = restRequest.param("queryId"); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( - TransportGetQueryResultRequestAction.ACTION_TYPE, - new GetJobQueryResultActionRequest(jobId), + TransportGetAsyncQueryResultAction.ACTION_TYPE, + new GetAsyncQueryResultActionRequest(queryId), new ActionListener<>() { @Override public void onResponse( - GetJobQueryResultActionResponse getJobQueryResultActionResponse) { - restChannel.sendResponse( - new BytesRestResponse( - RestStatus.OK, - "application/json; charset=UTF-8", - getJobQueryResultActionResponse.getResult())); - } - - @Override - public void onFailure(Exception e) { - handleException(e, restChannel); - } - })); - } - - private RestChannelConsumer executeGetJobRequest(NodeClient nodeClient, RestRequest restRequest) { - String jobId = restRequest.param("jobId"); - return restChannel -> - Scheduler.schedule( - nodeClient, - () -> - nodeClient.execute( - TransportGetJobRequestAction.ACTION_TYPE, - new GetJobActionRequest(jobId), - new ActionListener<>() { - @Override - public void onResponse(GetJobActionResponse getJobActionResponse) { + GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse) { restChannel.sendResponse( new BytesRestResponse( RestStatus.OK, "application/json; charset=UTF-8", - getJobActionResponse.getResult())); + getAsyncQueryResultActionResponse.getResult())); } @Override @@ -226,22 +180,23 @@ private void handleException(Exception e, RestChannel restChannel) { } private RestChannelConsumer executeDeleteRequest(RestRequest restRequest, NodeClient nodeClient) { - String jobId = restRequest.param("jobId"); + String queryId = restRequest.param("queryId"); return restChannel -> Scheduler.schedule( nodeClient, () -> nodeClient.execute( - TransportDeleteJobRequestAction.ACTION_TYPE, - new DeleteJobActionRequest(jobId), + TransportCancelAsyncQueryRequestAction.ACTION_TYPE, + new CancelAsyncQueryActionRequest(queryId), new ActionListener<>() { @Override - public void onResponse(DeleteJobActionResponse deleteJobActionResponse) { + public void onResponse( + CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse) { restChannel.sendResponse( new BytesRestResponse( RestStatus.OK, "application/json; charset=UTF-8", - deleteJobActionResponse.getResult())); + cancelAsyncQueryActionResponse.getResult())); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java similarity index 71% rename from spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java rename to spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index ef29e857c8..1e46ae48d2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -14,22 +14,27 @@ @Data @AllArgsConstructor -public class CreateJobRequest { +public class CreateAsyncQueryRequest { private String query; + private String lang; - public static CreateJobRequest fromXContentParser(XContentParser parser) throws IOException { + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) + 
throws IOException { String query = null; + String lang = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); if (fieldName.equals("query")) { query = parser.textOrNull(); + } else if (fieldName.equals("kind")) { + lang = parser.textOrNull(); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - return new CreateJobRequest(query); + return new CreateAsyncQueryRequest(query, lang); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java new file mode 100644 index 0000000000..8cfe57c2a6 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryResponse.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class CreateAsyncQueryResponse { + private String queryId; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java new file mode 100644 index 0000000000..990dbccd0b --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -0,0 +1,41 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportCancelAsyncQueryRequestAction + extends HandledTransportAction { + + public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); + + @Inject + public TransportCancelAsyncQueryRequestAction( + TransportService transportService, ActionFilters actionFilters) { + super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); + } + + @Override + protected void doExecute( + Task task, + CancelAsyncQueryActionRequest request, + ActionListener listener) { + String responseContent = "deleted_job"; + listener.onResponse(new CancelAsyncQueryActionResponse(responseContent)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java new file mode 100644 index 0000000000..991eafdad9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java @@ -0,0 +1,64 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package 
org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportCreateAsyncQueryRequestAction + extends HandledTransportAction { + + private final AsyncQueryExecutorService asyncQueryExecutorService; + + public static final String NAME = "cluster:admin/opensearch/ql/async_query/create"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, CreateAsyncQueryActionResponse::new); + + @Inject + public TransportCreateAsyncQueryRequestAction( + TransportService transportService, + ActionFilters actionFilters, + AsyncQueryExecutorServiceImpl jobManagementService) { + super(NAME, transportService, actionFilters, CreateAsyncQueryActionRequest::new); + this.asyncQueryExecutorService = jobManagementService; + } + + @Override + protected void doExecute( + Task task, + CreateAsyncQueryActionRequest request, + ActionListener listener) { + try { + CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest(); + CreateAsyncQueryResponse createAsyncQueryResponse = + asyncQueryExecutorService.createAsyncQuery(createAsyncQueryRequest); + String responseContent = + new JsonResponseFormatter(JsonResponseFormatter.Style.PRETTY) { + @Override + protected Object buildJsonObject(CreateAsyncQueryResponse response) { + return response; + } + }.format(createAsyncQueryResponse); + listener.onResponse(new CreateAsyncQueryActionResponse(responseContent)); + } catch (Exception e) { + listener.onFailure(e); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java deleted file mode 100644 index 53ae9fad90..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestAction.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import org.opensearch.action.ActionType; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; -import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -public class TransportCreateJobRequestAction - extends HandledTransportAction { - - public static final String NAME = "cluster:admin/opensearch/ql/jobs/create"; - public static final ActionType ACTION_TYPE = - new 
ActionType<>(NAME, CreateJobActionResponse::new); - - @Inject - public TransportCreateJobRequestAction( - TransportService transportService, ActionFilters actionFilters) { - super(NAME, transportService, actionFilters, CreateJobActionRequest::new); - } - - @Override - protected void doExecute( - Task task, CreateJobActionRequest request, ActionListener listener) { - String responseContent = "submitted_job"; - listener.onResponse(new CreateJobActionResponse(responseContent)); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java deleted file mode 100644 index dcccb76272..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestAction.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import org.opensearch.action.ActionType; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.DeleteJobActionRequest; -import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -public class TransportDeleteJobRequestAction - extends HandledTransportAction { - - public static final String NAME = "cluster:admin/opensearch/ql/jobs/delete"; - public static final ActionType ACTION_TYPE = - new ActionType<>(NAME, DeleteJobActionResponse::new); - - @Inject - public TransportDeleteJobRequestAction( - TransportService transportService, ActionFilters actionFilters) { - super(NAME, transportService, actionFilters, DeleteJobActionRequest::new); - } - - @Override - protected void doExecute( - Task task, DeleteJobActionRequest request, ActionListener listener) { - String responseContent = "deleted_job"; - listener.onResponse(new DeleteJobActionResponse(responseContent)); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java new file mode 100644 index 0000000000..c23706b184 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -0,0 +1,70 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.protocol.response.format.ResponseFormatter; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import 
org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class TransportGetAsyncQueryResultAction + extends HandledTransportAction< + GetAsyncQueryResultActionRequest, GetAsyncQueryResultActionResponse> { + + private final AsyncQueryExecutorService asyncQueryExecutorService; + + public static final String NAME = "cluster:admin/opensearch/ql/async_query/result"; + public static final ActionType ACTION_TYPE = + new ActionType<>(NAME, GetAsyncQueryResultActionResponse::new); + + @Inject + public TransportGetAsyncQueryResultAction( + TransportService transportService, + ActionFilters actionFilters, + AsyncQueryExecutorServiceImpl jobManagementService) { + super(NAME, transportService, actionFilters, GetAsyncQueryResultActionRequest::new); + this.asyncQueryExecutorService = jobManagementService; + } + + @Override + protected void doExecute( + Task task, + GetAsyncQueryResultActionRequest request, + ActionListener listener) { + try { + String jobId = request.getQueryId(); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = + asyncQueryExecutorService.getAsyncQueryResults(jobId); + ResponseFormatter formatter = + new AsyncQueryResultResponseFormatter(JsonResponseFormatter.Style.PRETTY); + String responseContent = + formatter.format( + new AsyncQueryResult( + asyncQueryExecutionResponse.getStatus(), + asyncQueryExecutionResponse.getSchema(), + asyncQueryExecutionResponse.getResults(), + Cursor.None)); + listener.onResponse(new GetAsyncQueryResultActionResponse(responseContent)); + } catch (Exception e) { + listener.onFailure(e); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java deleted file mode 100644 index 96e002bd81..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetJobRequestAction.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import org.opensearch.action.ActionType; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.GetJobActionRequest; -import org.opensearch.sql.spark.transport.model.GetJobActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -public class TransportGetJobRequestAction - extends HandledTransportAction { - - public static final String NAME = "cluster:admin/opensearch/ql/jobs/read"; - public static final ActionType ACTION_TYPE = - new ActionType<>(NAME, GetJobActionResponse::new); - - @Inject - public TransportGetJobRequestAction( - TransportService transportService, ActionFilters actionFilters) { - super(NAME, transportService, actionFilters, GetJobActionRequest::new); - } - - @Override - protected void doExecute( - Task task, GetJobActionRequest request, ActionListener listener) { - String responseContent; - if (request.getJobId() == null) { - responseContent = handleGetAllJobs(); - } else { - responseContent = 
handleGetJob(request.getJobId()); - } - listener.onResponse(new GetJobActionResponse(responseContent)); - } - - private String handleGetAllJobs() { - return "All Jobs Information."; - } - - private String handleGetJob(String jobId) { - return String.format("Job %s details.", jobId); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java deleted file mode 100644 index 6aba1b48b6..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestAction.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import org.opensearch.action.ActionType; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -public class TransportGetQueryResultRequestAction - extends HandledTransportAction< - GetJobQueryResultActionRequest, GetJobQueryResultActionResponse> { - - public static final String NAME = "cluster:admin/opensearch/ql/jobs/result"; - public static final ActionType ACTION_TYPE = - new ActionType<>(NAME, GetJobQueryResultActionResponse::new); - - @Inject - public TransportGetQueryResultRequestAction( - TransportService transportService, ActionFilters actionFilters) { - super(NAME, transportService, actionFilters, GetJobQueryResultActionRequest::new); - } - - @Override - protected void doExecute( - Task task, - GetJobQueryResultActionRequest request, - ActionListener listener) { - String responseContent = "job result"; - listener.onResponse(new GetJobQueryResultActionResponse(responseContent)); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java new file mode 100644 index 0000000000..c9eb5bbf59 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.transport.format; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.util.List; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.protocol.response.QueryResult; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; + +/** + * JSON response format with schema header and data rows. For example, + * + *
+ *  {
+ *      "schema": [
+ *          {
+ *              "name": "name",
+ *              "type": "string"
+ *          }
+ *      ],
+ *      "datarows": [
+ *          ["John"],
+ *          ["Smith"]
+ *      ],
+ *      "total": 2,
+ *      "size": 2
+ *  }
+ * 
+ */ +public class AsyncQueryResultResponseFormatter extends JsonResponseFormatter { + + public AsyncQueryResultResponseFormatter(Style style) { + super(style); + } + + @Override + public Object buildJsonObject(AsyncQueryResult response) { + JsonResponse.JsonResponseBuilder json = JsonResponse.builder(); + if (response.getStatus().equalsIgnoreCase("success")) { + json.total(response.size()).size(response.size()); + json.schema( + response.columnNameTypes().entrySet().stream() + .map((entry) -> new Column(entry.getKey(), entry.getValue())) + .collect(Collectors.toList())); + json.datarows(fetchDataRows(response)); + } + json.status(response.getStatus()); + return json.build(); + } + + private Object[][] fetchDataRows(QueryResult response) { + Object[][] rows = new Object[response.size()][]; + int i = 0; + for (Object[] values : response) { + rows[i++] = values; + } + return rows; + } + + /** org.json requires these inner data classes be public (and static) */ + @Builder + @Getter + @JsonIgnoreProperties(ignoreUnknown = true) + public static class JsonResponse { + + private final String status; + + private final List schema; + + private final Object[][] datarows; + + private Integer total; + private Integer size; + } + + @RequiredArgsConstructor + @Getter + public static class Column { + private final String name; + private final String type; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java similarity index 77% rename from spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java rename to spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java index eaf379047a..e12f184efe 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java @@ -14,12 +14,12 @@ import org.opensearch.core.common.io.stream.StreamInput; @AllArgsConstructor -public class DeleteJobActionRequest extends ActionRequest { +public class CancelAsyncQueryActionRequest extends ActionRequest { - private String jobId; + private String queryId; /** Constructor of SubmitJobActionRequest from StreamInput. 
*/ - public DeleteJobActionRequest(StreamInput in) throws IOException { + public CancelAsyncQueryActionRequest(StreamInput in) throws IOException { super(in); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionResponse.java similarity index 81% rename from spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java rename to spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionResponse.java index ce76d4a20d..af97140b49 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionResponse.java @@ -15,11 +15,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; @RequiredArgsConstructor -public class CreateJobActionResponse extends ActionResponse { +public class CancelAsyncQueryActionResponse extends ActionResponse { @Getter private final String result; - public CreateJobActionResponse(StreamInput in) throws IOException { + public CancelAsyncQueryActionResponse(StreamInput in) throws IOException { super(in); result = in.readString(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateAsyncQueryActionRequest.java similarity index 55% rename from spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java rename to spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateAsyncQueryActionRequest.java index cbdcb617af..bcb329b2dc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateJobActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateAsyncQueryActionRequest.java @@ -12,19 +12,19 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.sql.spark.rest.model.CreateJobRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; -public class CreateJobActionRequest extends ActionRequest { +public class CreateAsyncQueryActionRequest extends ActionRequest { - @Getter private CreateJobRequest createJobRequest; + @Getter private CreateAsyncQueryRequest createAsyncQueryRequest; /** Constructor of CreateJobActionRequest from StreamInput. 
*/ - public CreateJobActionRequest(StreamInput in) throws IOException { + public CreateAsyncQueryActionRequest(StreamInput in) throws IOException { super(in); } - public CreateJobActionRequest(CreateJobRequest createJobRequest) { - this.createJobRequest = createJobRequest; + public CreateAsyncQueryActionRequest(CreateAsyncQueryRequest createAsyncQueryRequest) { + this.createAsyncQueryRequest = createAsyncQueryRequest; } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateAsyncQueryActionResponse.java similarity index 81% rename from spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java rename to spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateAsyncQueryActionResponse.java index f904afdb4e..de5acc2537 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CreateAsyncQueryActionResponse.java @@ -15,11 +15,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; @RequiredArgsConstructor -public class GetJobActionResponse extends ActionResponse { +public class CreateAsyncQueryActionResponse extends ActionResponse { @Getter private final String result; - public GetJobActionResponse(StreamInput in) throws IOException { + public CreateAsyncQueryActionResponse(StreamInput in) throws IOException { super(in); result = in.readString(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java deleted file mode 100644 index 38be57c21d..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/DeleteJobActionResponse.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport.model; - -import java.io.IOException; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import org.opensearch.core.action.ActionResponse; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; - -@RequiredArgsConstructor -public class DeleteJobActionResponse extends ActionResponse { - - @Getter private final String result; - - public DeleteJobActionResponse(StreamInput in) throws IOException { - super(in); - result = in.readString(); - } - - @Override - public void writeTo(StreamOutput streamOutput) throws IOException { - streamOutput.writeString(result); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java similarity index 76% rename from spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java rename to spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java index 1de7bae2c7..06faa75a26 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionRequest.java @@ -15,12 +15,12 @@ import org.opensearch.core.common.io.stream.StreamInput; @AllArgsConstructor -public class GetJobQueryResultActionRequest extends 
ActionRequest { +public class GetAsyncQueryResultActionRequest extends ActionRequest { - @Getter private String jobId; + @Getter private String queryId; /** Constructor of GetJobQueryResultActionRequest from StreamInput. */ - public GetJobQueryResultActionRequest(StreamInput in) throws IOException { + public GetAsyncQueryResultActionRequest(StreamInput in) throws IOException { super(in); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionResponse.java similarity index 80% rename from spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java rename to spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionResponse.java index a7a8002c67..bb77bb131a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobQueryResultActionResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetAsyncQueryResultActionResponse.java @@ -15,11 +15,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; @RequiredArgsConstructor -public class GetJobQueryResultActionResponse extends ActionResponse { +public class GetAsyncQueryResultActionResponse extends ActionResponse { @Getter private final String result; - public GetJobQueryResultActionResponse(StreamInput in) throws IOException { + public GetAsyncQueryResultActionResponse(StreamInput in) throws IOException { super(in); result = in.readString(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java deleted file mode 100644 index f8969cde15..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/GetJobActionRequest.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport.model; - -import java.io.IOException; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.common.io.stream.StreamInput; - -@NoArgsConstructor -@AllArgsConstructor -public class GetJobActionRequest extends ActionRequest { - - @Getter private String jobId; - - /** Constructor of GetJobActionRequest from StreamInput. */ - public GetJobActionRequest(StreamInput in) throws IOException { - super(in); - } - - @Override - public ActionRequestValidationException validate() { - return null; - } -} diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml new file mode 100644 index 0000000000..ec2c83a4df --- /dev/null +++ b/spark/src/main/resources/job-metadata-index-mapping.yml @@ -0,0 +1,20 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the .ql-job-metadata index +# Also "dynamic" is set to "false" so that other fields can be added. 
+dynamic: false
+properties:
+  jobId:
+    type: text
+    fields:
+      keyword:
+        type: keyword
+  applicationId:
+    type: text
+    fields:
+      keyword:
+        type: keyword
\ No newline at end of file
diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/job-metadata-index-settings.yml
new file mode 100644
index 0000000000..be93f4645c
--- /dev/null
+++ b/spark/src/main/resources/job-metadata-index-settings.yml
@@ -0,0 +1,11 @@
+---
+##
+# Copyright OpenSearch Contributors
+# SPDX-License-Identifier: Apache-2.0
+##
+
+# Settings file for the .ql-job-metadata index
+index:
+  number_of_shards: "1"
+  auto_expand_replicas: "0-2"
+  number_of_replicas: "0"
\ No newline at end of file
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
new file mode 100644
index 0000000000..cf04278892
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.asyncquery;
+
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
+import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
+import static org.opensearch.sql.spark.utils.TestUtils.getJson;
+
+import com.amazonaws.services.emrserverless.model.JobRunState;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Optional;
+import org.json.JSONObject;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
+import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
+import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
+
+@ExtendWith(MockitoExtension.class)
+public class AsyncQueryExecutorServiceImplTest {
+
+  @Mock private SparkQueryDispatcher sparkQueryDispatcher;
+  @Mock private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
+  @Mock private Settings settings;
+
+  @Test
+  void testCreateAsyncQuery() {
+    AsyncQueryExecutorServiceImpl jobExecutorService =
+        new AsyncQueryExecutorServiceImpl(
+            asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
+    CreateAsyncQueryRequest createAsyncQueryRequest =
+        new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", "sql");
+    when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG))
+        .thenReturn(
+            "{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}");
+    when(sparkQueryDispatcher.dispatch(
+            "00fd775baqpu4g0p",
+            "select * from my_glue.default.http_logs",
+            "arn:aws:iam::270824043731:role/emr-job-execution-role"))
+
.thenReturn(EMR_JOB_ID); + CreateAsyncQueryResponse createAsyncQueryResponse = + jobExecutorService.createAsyncQuery(createAsyncQueryRequest); + verify(asyncQueryJobMetadataStorageService, times(1)) + .storeJobMetadata(new AsyncQueryJobMetadata(EMR_JOB_ID, "00fd775baqpu4g0p")); + verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + verify(sparkQueryDispatcher, times(1)) + .dispatch( + "00fd775baqpu4g0p", + "select * from my_glue.default.http_logs", + "arn:aws:iam::270824043731:role/emr-job-execution-role"); + Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId()); + } + + @Test + void testGetAsyncQueryResultsWithJobNotFoundException() { + AsyncQueryExecutorServiceImpl jobExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.empty()); + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> jobExecutorService.getAsyncQueryResults(EMR_JOB_ID)); + Assertions.assertEquals( + "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); + verifyNoInteractions(sparkQueryDispatcher); + verifyNoInteractions(settings); + } + + @Test + void testGetAsyncQueryResultsWithInProgressJob() { + AsyncQueryExecutorServiceImpl jobExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + JSONObject jobResult = new JSONObject(); + jobResult.put("status", JobRunState.PENDING.toString()); + when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(jobResult); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = + jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); + + Assertions.assertNull(asyncQueryExecutionResponse.getResults()); + Assertions.assertNull(asyncQueryExecutionResponse.getSchema()); + Assertions.assertEquals("PENDING", asyncQueryExecutionResponse.getStatus()); + verifyNoInteractions(settings); + } + + @Test + void testGetAsyncQueryResultsWithSuccessJob() throws IOException { + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); + jobResult.put("status", JobRunState.SUCCESS.toString()); + when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(jobResult); + + AsyncQueryExecutorServiceImpl jobExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = + jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); + + Assertions.assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); + Assertions.assertEquals(1, asyncQueryExecutionResponse.getSchema().getColumns().size()); + Assertions.assertEquals( + "1", asyncQueryExecutionResponse.getSchema().getColumns().get(0).getName()); + Assertions.assertEquals( + 1, + ((HashMap) asyncQueryExecutionResponse.getResults().get(0).value()) + .get("1")); + verifyNoInteractions(settings); + } + + @Test + void 
testGetAsyncQueryResultsWithDisabledExecutionEngine() { + AsyncQueryExecutorService asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> asyncQueryExecutorService.getAsyncQueryResults(EMR_JOB_ID)); + Assertions.assertEquals( + "Async Query APIs are disabled as plugins.query.executionengine.spark.config is not" + + " configured in cluster settings. Please configure the setting and restart the domain" + + " to enable Async Query APIs", + illegalArgumentException.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java new file mode 100644 index 0000000000..fe9da12ef0 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery; + +import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService.JOB_METADATA_INDEX; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; + +import java.util.Optional; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.ArgumentMatchers; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; + +@ExtendWith(MockitoExtension.class) +public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private SearchResponse searchResponse; + + @Mock private ActionFuture searchResponseActionFuture; + @Mock private ActionFuture createIndexResponseActionFuture; + @Mock private ActionFuture indexResponseActionFuture; + @Mock private IndexResponse indexResponse; + @Mock private SearchHit searchHit; + + @InjectMocks + private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; + + @Test + public void testStoreJobMetadata() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + 
Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + + this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); + } + + @Test + public void testStoreJobMetadataWithOutCreatingIndex() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.TRUE); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + + this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); + + Mockito.verify(client.admin().indices(), Mockito.times(0)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext(); + } + + @Test + public void testStoreJobMetadataWithException() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())) + .thenThrow(new RuntimeException("error while indexing")); + + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata)); + Assertions.assertEquals( + "java.lang.RuntimeException: error while indexing", runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); + } + + @Test + public void testStoreJobMetadataWithIndexCreationFailed() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX)); + + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + RuntimeException runtimeException = 
+ Assertions.assertThrows( + RuntimeException.class, + () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata)); + Assertions.assertEquals( + "Internal server error while creating.ql-job-metadata index:: " + + "Index creation is not acknowledged.", + runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext(); + } + + @Test + public void testStoreJobMetadataFailedWithNotFoundResponse() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata)); + Assertions.assertEquals( + "Saving job metadata information failed with result : not_found", + runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); + } + + @Test + public void testGetJobMetadata() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(true); + Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); + Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK); + Mockito.when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString()); + + Optional jobMetadataOptional = + opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID); + Assertions.assertTrue(jobMetadataOptional.isPresent()); + Assertions.assertEquals(EMR_JOB_ID, jobMetadataOptional.get().getJobId()); + Assertions.assertEquals(EMRS_APPLICATION_ID, jobMetadataOptional.get().getApplicationId()); + } + + @Test + public void testGetJobMetadataWith404SearchResponse() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(true); + Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); + Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + Mockito.when(searchResponse.status()).thenReturn(RestStatus.NOT_FOUND); + + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> 
opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)); + Assertions.assertEquals( + "Fetching job metadata information failed with status : NOT_FOUND", + runtimeException.getMessage()); + } + + @Test + public void testGetJobMetadataWithParsingFailed() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(true); + Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); + Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK); + Mockito.when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); + Mockito.when(searchHit.getSourceAsString()).thenReturn("..tesJOBs"); + + Assertions.assertThrows( + RuntimeException.class, + () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)); + } + + @Test + public void testGetJobMetadataWithNoIndex() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + + Optional jobMetadata = + opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID); + + Assertions.assertFalse(jobMetadata.isPresent()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java new file mode 100644 index 0000000000..36f10cd08b --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -0,0 +1,48 @@ +/* Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_JOB_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; + +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class EmrServerlessClientImplTest { + @Mock private AWSEMRServerless emrServerless; + + @Test + void testStartJobRun() { + StartJobRunResult response = new StartJobRunResult(); + when(emrServerless.startJobRun(any())).thenReturn(response); + + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + emrServerlessClient.startJobRun( + QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, 
SPARK_SUBMIT_PARAMETERS); + } + + @Test + void testGetJobRunState() { + JobRun jobRun = new JobRun(); + jobRun.setState("Running"); + GetJobRunResult response = new GetJobRunResult(); + response.setJobRun(jobRun); + when(emrServerless.getJobRun(any())).thenReturn(response); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index 2b1020568a..e455e6a049 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -7,5 +7,12 @@ public class TestConstants { public static final String QUERY = "select 1"; + public static final String TEST_DATASOURCE_NAME = "test_datasource_name"; public static final String EMR_CLUSTER_ID = "j-123456789"; + public static final String EMR_JOB_ID = "job-123xxx"; + public static final String EMRS_APPLICATION_ID = "app-xxxxx"; + public static final String EMRS_EXECUTION_ROLE = "execution_role"; + public static final String EMRS_DATASOURCE_ROLE = "datasource_role"; + public static final String EMRS_JOB_NAME = "job_name"; + public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java new file mode 100644 index 0000000000..800bd59b72 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -0,0 +1,174 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import com.amazonaws.services.emrserverless.model.JobRun; +import com.amazonaws.services.emrserverless.model.JobRunState; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@ExtendWith(MockitoExtension.class) +public class SparkQueryDispatcherTest { + + @Mock private SparkJobClient sparkJobClient; + @Mock private DataSourceService dataSourceService; + @Mock private JobExecutionResponseReader jobExecutionResponseReader; + + @Test + void 
testDispatch() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(sparkJobClient.startJobRun( + QUERY, + "flint-opensearch-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString())) + .thenReturn(EMR_JOB_ID); + when(dataSourceService.getRawDataSourceMetadata("my_glue")) + .thenReturn(constructMyGlueDataSourceMetadata()); + String jobId = sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE); + verify(sparkJobClient, times(1)) + .startJobRun( + QUERY, + "flint-opensearch-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString()); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + + @Test + void testDispatchWithWrongURI() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(dataSourceService.getRawDataSourceMetadata("my_glue")) + .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE)); + Assertions.assertEquals( + "Bad URI in indexstore configuration of the : my_glue datasoure.", + illegalArgumentException.getMessage()); + } + + private DataSourceMetadata constructMyGlueDataSourceMetadata() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put( + "glue.indexstore.opensearch.uri", + "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } + + private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? 
param"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } + + @Test + void testGetQueryResponse() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(sparkJobClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); + JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals("PENDING", result.get("status")); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testGetQueryResponseWithSuccess() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(sparkJobClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); + JSONObject queryResult = new JSONObject(); + queryResult.put("data", "result"); + when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) + .thenReturn(queryResult); + JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(sparkJobClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); + Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); + Assertions.assertEquals("result", result.get("data")); + Assertions.assertEquals("SUCCESS", result.get("status")); + } + + String constructExpectedSparkSubmitParameterString() { + return " --class org.opensearch.sql.FlintJob --conf" + + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + + " --conf" + + " spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory" + + " --conf" + + " spark.jars=s3://flint-data-dp-eu-west-1-beta/code/flint/AWSGlueDataCatalogHiveMetaStoreAuth-1.0.jar,s3://flint-data-dp-eu-west-1-beta/code/flint/flint-catalog.jar" + + " --conf" + + " spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.1.0-SNAPSHOT" + + " --conf" + + " spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots" + + " --conf" + + " spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" + + " --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" + + " --conf" + + " spark.datasource.flint.host=search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com" + + " --conf spark.datasource.flint.port=-1 --conf" + + " spark.datasource.flint.scheme=https --conf spark.datasource.flint.auth=sigv4 " + + " --conf spark.datasource.flint.region=eu-west-1 --conf" + + " spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + + " --conf spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions " + + " --conf" + + " spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory" + + " --conf" + + " 
spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + + " --conf" + + " spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + + " --conf" + + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog "; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java new file mode 100644 index 0000000000..17305fb905 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/response/AsyncQueryExecutionResponseReaderTest.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; + +import java.util.Map; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; + +@ExtendWith(MockitoExtension.class) +public class AsyncQueryExecutionResponseReaderTest { + @Mock private Client client; + @Mock private SearchResponse searchResponse; + @Mock private SearchHit searchHit; + @Mock private ActionFuture searchResponseActionFuture; + + @Test + public void testGetResultFromOpensearchIndex() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); + Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID)); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + assertFalse(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID).isEmpty()); + } + + @Test + public void testInvalidSearchResponse() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT); + + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + RuntimeException exception = + assertThrows( + RuntimeException.class, + () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)); + Assertions.assertEquals( + "Fetching result from " + + SPARK_RESPONSE_BUFFER_INDEX_NAME + + " index failed with status : 
" + + RestStatus.NO_CONTENT, + exception.getMessage()); + } + + @Test + public void testSearchFailure() { + when(client.search(any())).thenThrow(RuntimeException.class); + JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); + assertThrows( + RuntimeException.class, + () -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java index 211561ac72..e234454021 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java @@ -10,7 +10,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_RESPONSE_BUFFER_INDEX_NAME; import java.util.Map; import org.apache.lucene.search.TotalHits; @@ -69,7 +69,7 @@ public void testInvalidSearchResponse() { assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex()); Assertions.assertEquals( "Fetching result from " - + SPARK_INDEX_NAME + + SPARK_RESPONSE_BUFFER_INDEX_NAME + " index failed with status : " + RestStatus.NO_CONTENT, exception.getMessage()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java similarity index 57% rename from spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java rename to spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index 828b264343..c560c882c0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportDeleteJobRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -19,35 +19,37 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.DeleteJobActionRequest; -import org.opensearch.sql.spark.transport.model.DeleteJobActionResponse; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; +import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @ExtendWith(MockitoExtension.class) -public class TransportDeleteJobRequestActionTest { +public class TransportCancelAsyncQueryRequestActionTest { @Mock private TransportService transportService; - @Mock private TransportDeleteJobRequestAction action; + @Mock private TransportCancelAsyncQueryRequestAction action; @Mock private Task task; - @Mock private ActionListener actionListener; + @Mock private ActionListener actionListener; - @Captor private ArgumentCaptor deleteJobActionResponseArgumentCaptor; + @Captor + private ArgumentCaptor deleteJobActionResponseArgumentCaptor; @BeforeEach public void setUp() { action = - new TransportDeleteJobRequestAction(transportService, new ActionFilters(new HashSet<>())); + new 
TransportCancelAsyncQueryRequestAction( + transportService, new ActionFilters(new HashSet<>())); } @Test public void testDoExecute() { - DeleteJobActionRequest request = new DeleteJobActionRequest("jobId"); + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest("jobId"); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); - DeleteJobActionResponse deleteJobActionResponse = + CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse = deleteJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("deleted_job", deleteJobActionResponse.getResult()); + Assertions.assertEquals("deleted_job", cancelAsyncQueryActionResponse.getResult()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java new file mode 100644 index 0000000000..6596a9e820 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -0,0 +1,88 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.HashSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; +import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportCreateAsyncQueryRequestActionTest { + + @Mock private TransportService transportService; + @Mock private TransportCreateAsyncQueryRequestAction action; + @Mock private AsyncQueryExecutorServiceImpl jobExecutorService; + @Mock private Task task; + @Mock private ActionListener actionListener; + + @Captor + private ArgumentCaptor createJobActionResponseArgumentCaptor; + + @Captor private ArgumentCaptor exceptionArgumentCaptor; + + @BeforeEach + public void setUp() { + action = + new TransportCreateAsyncQueryRequestAction( + transportService, new ActionFilters(new HashSet<>()), jobExecutorService); + } + + @Test + public void testDoExecute() { + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "sql"); + CreateAsyncQueryActionRequest request = + new CreateAsyncQueryActionRequest(createAsyncQueryRequest); + when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) + .thenReturn(new CreateAsyncQueryResponse("123")); + action.doExecute(task, request, actionListener); + 
Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + CreateAsyncQueryActionResponse createAsyncQueryActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + " \"queryId\": \"123\"\n" + "}", createAsyncQueryActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "sql"); + CreateAsyncQueryActionRequest request = + new CreateAsyncQueryActionRequest(createAsyncQueryRequest); + doThrow(new RuntimeException("Error")) + .when(jobExecutorService) + .createAsyncQuery(createAsyncQueryRequest); + action.doExecute(task, request, actionListener); + verify(jobExecutorService, times(1)).createAsyncQuery(createAsyncQueryRequest); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("Error", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java deleted file mode 100644 index 4357899368..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateJobRequestActionTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import java.util.HashSet; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.rest.model.CreateJobRequest; -import org.opensearch.sql.spark.transport.model.CreateJobActionRequest; -import org.opensearch.sql.spark.transport.model.CreateJobActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -@ExtendWith(MockitoExtension.class) -public class TransportCreateJobRequestActionTest { - - @Mock private TransportService transportService; - @Mock private TransportCreateJobRequestAction action; - @Mock private Task task; - @Mock private ActionListener actionListener; - - @Captor private ArgumentCaptor createJobActionResponseArgumentCaptor; - - @BeforeEach - public void setUp() { - action = - new TransportCreateJobRequestAction(transportService, new ActionFilters(new HashSet<>())); - } - - @Test - public void testDoExecute() { - CreateJobRequest createJobRequest = new CreateJobRequest("source = my_glue.default.alb_logs"); - CreateJobActionRequest request = new CreateJobActionRequest(createJobRequest); - - action.doExecute(task, request, actionListener); - Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); - CreateJobActionResponse createJobActionResponse = - createJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("submitted_job", createJobActionResponse.getResult()); - } -} diff --git 
a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java new file mode 100644 index 0000000000..9e4cd75165 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -0,0 +1,139 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.transport; + +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.HashSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; +import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +@ExtendWith(MockitoExtension.class) +public class TransportGetAsyncQueryResultActionTest { + + @Mock private TransportService transportService; + @Mock private TransportGetAsyncQueryResultAction action; + @Mock private Task task; + @Mock private ActionListener actionListener; + @Mock private AsyncQueryExecutorServiceImpl jobExecutorService; + + @Captor + private ArgumentCaptor createJobActionResponseArgumentCaptor; + + @Captor private ArgumentCaptor exceptionArgumentCaptor; + + @BeforeEach + public void setUp() { + action = + new TransportGetAsyncQueryResultAction( + transportService, new ActionFilters(new HashSet<>()), jobExecutorService); + } + + @Test + public void testDoExecute() { + GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = + new AsyncQueryExecutionResponse("IN_PROGRESS", null, null); + when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + " \"status\": \"IN_PROGRESS\"\n" + "}", + getAsyncQueryResultActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithSuccessResponse() { + GetAsyncQueryResultActionRequest request = new 
GetAsyncQueryResultActionRequest("jobId"); + ExecutionEngine.Schema schema = + new ExecutionEngine.Schema( + ImmutableList.of( + new ExecutionEngine.Schema.Column("name", "name", STRING), + new ExecutionEngine.Schema.Column("age", "age", INTEGER))); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = + new AsyncQueryExecutionResponse( + "SUCCESS", + schema, + Arrays.asList( + tupleValue(ImmutableMap.of("name", "John", "age", 20)), + tupleValue(ImmutableMap.of("name", "Smith", "age", 30)))); + when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); + GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse = + createJobActionResponseArgumentCaptor.getValue(); + Assertions.assertEquals( + "{\n" + + " \"status\": \"SUCCESS\",\n" + + " \"schema\": [\n" + + " {\n" + + " \"name\": \"name\",\n" + + " \"type\": \"string\"\n" + + " },\n" + + " {\n" + + " \"name\": \"age\",\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " ],\n" + + " \"datarows\": [\n" + + " [\n" + + " \"John\",\n" + + " 20\n" + + " ],\n" + + " [\n" + + " \"Smith\",\n" + + " 30\n" + + " ]\n" + + " ],\n" + + " \"total\": 2,\n" + + " \"size\": 2\n" + + "}", + getAsyncQueryResultActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("123"); + doThrow(new AsyncQueryNotFoundException("JobId 123 not found")) + .when(jobExecutorService) + .getAsyncQueryResults("123"); + action.doExecute(task, request, actionListener); + verify(jobExecutorService, times(1)).getAsyncQueryResults("123"); + verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("JobId 123 not found", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java deleted file mode 100644 index 06d1ee8baf..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetJobRequestActionTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import java.util.HashSet; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.GetJobActionRequest; -import org.opensearch.sql.spark.transport.model.GetJobActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -@ExtendWith(MockitoExtension.class) -public class TransportGetJobRequestActionTest { - - @Mock private TransportService transportService; - @Mock private TransportGetJobRequestAction action; - @Mock private Task task; - @Mock private ActionListener actionListener; - - 
@Captor private ArgumentCaptor getJobActionResponseArgumentCaptor; - - @BeforeEach - public void setUp() { - action = new TransportGetJobRequestAction(transportService, new ActionFilters(new HashSet<>())); - } - - @Test - public void testDoExecuteWithSingleJob() { - GetJobActionRequest request = new GetJobActionRequest("abcd"); - - action.doExecute(task, request, actionListener); - Mockito.verify(actionListener).onResponse(getJobActionResponseArgumentCaptor.capture()); - GetJobActionResponse getJobActionResponse = getJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("Job abcd details.", getJobActionResponse.getResult()); - } - - @Test - public void testDoExecuteWithAllJobs() { - GetJobActionRequest request = new GetJobActionRequest(); - action.doExecute(task, request, actionListener); - Mockito.verify(actionListener).onResponse(getJobActionResponseArgumentCaptor.capture()); - GetJobActionResponse getJobActionResponse = getJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("All Jobs Information.", getJobActionResponse.getResult()); - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java deleted file mode 100644 index f22adead49..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportGetQueryResultRequestActionTest.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * - */ - -package org.opensearch.sql.spark.transport; - -import java.util.HashSet; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.core.action.ActionListener; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionRequest; -import org.opensearch.sql.spark.transport.model.GetJobQueryResultActionResponse; -import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; - -@ExtendWith(MockitoExtension.class) -public class TransportGetQueryResultRequestActionTest { - - @Mock private TransportService transportService; - @Mock private TransportGetQueryResultRequestAction action; - @Mock private Task task; - @Mock private ActionListener actionListener; - - @Captor - private ArgumentCaptor createJobActionResponseArgumentCaptor; - - @BeforeEach - public void setUp() { - action = - new TransportGetQueryResultRequestAction( - transportService, new ActionFilters(new HashSet<>())); - } - - @Test - public void testDoExecuteForSingleJob() { - GetJobQueryResultActionRequest request = new GetJobQueryResultActionRequest("jobId"); - action.doExecute(task, request, actionListener); - Mockito.verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); - GetJobQueryResultActionResponse getJobQueryResultActionResponse = - createJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("job result", getJobQueryResultActionResponse.getResult()); - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java 
b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java new file mode 100644 index 0000000000..5ba5627665 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java @@ -0,0 +1,40 @@ +package org.opensearch.sql.spark.transport.format; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.COMPACT; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; + +public class AsyncQueryResultResponseFormatterTest { + + private final ExecutionEngine.Schema schema = + new ExecutionEngine.Schema( + ImmutableList.of( + new ExecutionEngine.Schema.Column("firstname", null, STRING), + new ExecutionEngine.Schema.Column("age", null, INTEGER))); + + @Test + void formatAsyncQueryResponse() { + AsyncQueryResult response = + new AsyncQueryResult( + "success", + schema, + Arrays.asList( + tupleValue(ImmutableMap.of("firstname", "John", "age", 20)), + tupleValue(ImmutableMap.of("firstname", "Smith", "age", 30)))); + AsyncQueryResultResponseFormatter formatter = new AsyncQueryResultResponseFormatter(COMPACT); + assertEquals( + "{\"status\":\"success\",\"schema\":[{\"name\":\"firstname\",\"type\":\"string\"}," + + "{\"name\":\"age\",\"type\":\"integer\"}],\"datarows\":" + + "[[\"John\",20],[\"Smith\",30]],\"total\":2,\"size\":2}", + formatter.format(response)); + } +} From be8271455f210c148f6202672471aa47a3daaccc Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Thu, 21 Sep 2023 15:32:09 -0700 Subject: [PATCH 4/5] Cancel Job API (#2126) Signed-off-by: Vamsi Manohar --- .../rest/RestDataSourceQueryAction.java | 6 +- docs/user/interfaces/asyncqueryinterface.rst | 27 ++++--- .../asyncquery/AsyncQueryExecutorService.java | 8 ++ .../AsyncQueryExecutorServiceImpl.java | 11 +++ .../spark/client/EmrServerlessClientImpl.java | 20 +++++ .../sql/spark/client/SparkJobClient.java | 3 + .../dispatcher/SparkQueryDispatcher.java | 6 ++ .../rest/RestAsyncQueryManagementAction.java | 2 +- ...ransportCancelAsyncQueryRequestAction.java | 17 ++++- .../model/CancelAsyncQueryActionRequest.java | 2 + .../AsyncQueryExecutorServiceImplTest.java | 30 ++++++++ .../client/EmrServerlessClientImplTest.java | 29 +++++++ .../dispatcher/SparkQueryDispatcherTest.java | 76 +++++++++++-------- ...portCancelAsyncQueryRequestActionTest.java | 29 ++++++- 14 files changed, 213 insertions(+), 53 deletions(-) diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index b5929d0f20..2947afc5b9 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -88,8 +88,7 @@ public List routes() { new Route(GET, BASE_DATASOURCE_ACTION_URL), /* - * GET datasources - * Request URL: GET + * PUT datasources * Request body: * Ref * 
[org.opensearch.sql.plugin.transport.datasource.model.UpdateDataSourceActionRequest] @@ -100,8 +99,7 @@ public List routes() { new Route(PUT, BASE_DATASOURCE_ACTION_URL), /* - * GET datasources - * Request URL: GET + * DELETE datasources * Request body: Ref * [org.opensearch.sql.plugin.transport.datasource.model.DeleteDataSourceActionRequest] * Response body: Ref diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 98990b795b..f59afe8180 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -32,16 +32,13 @@ We make use of default aws credentials chain to make calls to the emr serverless have pass role permissions for emr-job-execution-role mentioned in the engine configuration. - Async Query Creation API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``. -HTTP URI: _plugins/_query/_async_query +HTTP URI: _plugins/_async_query HTTP VERB: POST - - Sample Request:: curl --location 'http://localhost:9200/_plugins/_async_query' \ @@ -57,23 +54,19 @@ Sample Response:: "queryId": "00fd796ut1a7eg0q" } + Async Query Result API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/result``. Async Query Creation and Result Query permissions are orthogonal, so any user with result api permissions and queryId can query the corresponding query results irrespective of the user who created the async query. - -HTTP URI: _plugins/_query/_async_query/{queryId} +HTTP URI: _plugins/_async_query/{queryId} HTTP VERB: GET - Sample Request BODY:: curl --location --request GET 'http://localhost:9200/_plugins/_async_query/00fd796ut1a7eg0q' \ --header 'Content-Type: application/json' \ - --data '{ - "query" : "select * from default.http_logs limit 1" - }' Sample Response if the Query is in Progress :: @@ -106,3 +99,17 @@ Sample Response If the Query is successful :: "total": 1, "size": 1 } + + +Async Query Cancellation API +====================================== +If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/jobs/delete``. + +HTTP URI: _plugins/_async_query/{queryId} +HTTP VERB: DELETE + +Sample Request Body :: + + curl --location --request DELETE 'http://localhost:9200/_plugins/_async_query/00fdalrvgkbh2g0q' \ + --header 'Content-Type: application/json' \ + diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index df13daa2a2..7caa69293a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -29,4 +29,12 @@ public interface AsyncQueryExecutorService { * @return {@link AsyncQueryExecutionResponse} */ AsyncQueryExecutionResponse getAsyncQueryResults(String queryId); + + /** + * Cancels running async query and returns the cancelled queryId. + * + * @param queryId queryId. + * @return {@link String} cancelledQueryId. 
+ */ + String cancelQuery(String queryId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index e5ed65920e..efc23e08b5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -95,6 +95,17 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } + @Override + public String cancelQuery(String queryId) { + Optional asyncQueryJobMetadata = + asyncQueryJobMetadataStorageService.getJobMetadata(queryId); + if (asyncQueryJobMetadata.isPresent()) { + return sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadata.get().getApplicationId(), queryId); + } + throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); + } + private void validateSparkExecutionEngineSettings() { if (!isSparkJobExecutionEnabled) { throw new IllegalArgumentException( diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index b554c4cd23..2377b2f5da 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -9,12 +9,15 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunRequest; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunRequest; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobDriver; import com.amazonaws.services.emrserverless.model.SparkSubmit; import com.amazonaws.services.emrserverless.model.StartJobRunRequest; import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import com.amazonaws.services.emrserverless.model.ValidationException; import java.security.AccessController; import java.security.PrivilegedAction; import org.apache.logging.log4j.LogManager; @@ -65,4 +68,21 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { logger.info("Job Run state: " + getJobRunResult.getJobRun().getState()); return getJobRunResult; } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + CancelJobRunRequest cancelJobRunRequest = + new CancelJobRunRequest().withJobRunId(jobId).withApplicationId(applicationId); + try { + CancelJobRunResult cancelJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) + () -> emrServerless.cancelJobRun(cancelJobRunRequest)); + logger.info(String.format("Job : %s cancelled", cancelJobRunResult.getJobRunId())); + return cancelJobRunResult; + } catch (ValidationException e) { + throw new IllegalArgumentException( + String.format("Couldn't cancel the queryId: %s due to %s", jobId, e.getMessage())); + } + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java index ff9f4acedd..c6b3059c77 100644 --- 
a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java @@ -7,6 +7,7 @@ package org.opensearch.sql.spark.client; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; public interface SparkJobClient { @@ -19,4 +20,6 @@ String startJobRun( String sparkSubmitParams); GetJobRunResult getJobRunResult(String applicationId, String jobId); + + CancelJobRunResult cancelJobRun(String applicationId, String jobId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index f632ceaf6a..442838331f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -15,6 +15,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.net.URI; @@ -64,6 +65,11 @@ public JSONObject getQueryResponse(String applicationId, String queryId) { return result; } + public String cancelJob(String applicationId, String jobId) { + CancelJobRunResult cancelJobRunResult = sparkJobClient.cancelJobRun(applicationId, jobId); + return cancelJobRunResult.getJobRunId(); + } + // TODO: Analyze given query // Extract datasourceName // Apply Authorizaiton. 
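Editor's note on the cancel path introduced above: SparkQueryDispatcher.cancelJob is a straight delegation, forwarding the (applicationId, jobId) pair to SparkJobClient.cancelJobRun and surfacing the job run id from the returned CancelJobRunResult. A minimal sketch of that contract, using a hypothetical in-memory stub instead of a real EMR Serverless client and assuming the module's aws-java-sdk-emrserverless dependency is on the classpath, could look like:

    import com.amazonaws.services.emrserverless.model.CancelJobRunResult;

    public class CancelDelegationSketch {
      // Hypothetical stand-in for SparkJobClient; the single method mirrors the signature added in this patch.
      interface JobClient {
        CancelJobRunResult cancelJobRun(String applicationId, String jobId);
      }

      // Mirrors the delegation in SparkQueryDispatcher.cancelJob: call the client, return the cancelled job run id.
      static String cancelJob(JobClient client, String applicationId, String jobId) {
        return client.cancelJobRun(applicationId, jobId).getJobRunId();
      }

      public static void main(String[] args) {
        // Stub that simply echoes the requested job id back as the cancelled id.
        JobClient stub = (app, job) -> new CancelJobRunResult().withJobRunId(job);
        System.out.println(cancelJob(stub, "app-001", "job-123")); // prints job-123
      }
    }
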
diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 56484688dc..741501cd18 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -194,7 +194,7 @@ public void onResponse( CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse) { restChannel.sendResponse( new BytesRestResponse( - RestStatus.OK, + RestStatus.NO_CONTENT, "application/json; charset=UTF-8", cancelAsyncQueryActionResponse.getResult())); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 990dbccd0b..232a280db5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -12,6 +12,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -21,13 +22,17 @@ public class TransportCancelAsyncQueryRequestAction extends HandledTransportAction { public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; + private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); @Inject public TransportCancelAsyncQueryRequestAction( - TransportService transportService, ActionFilters actionFilters) { + TransportService transportService, + ActionFilters actionFilters, + AsyncQueryExecutorServiceImpl asyncQueryExecutorService) { super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); + this.asyncQueryExecutorService = asyncQueryExecutorService; } @Override @@ -35,7 +40,13 @@ protected void doExecute( Task task, CancelAsyncQueryActionRequest request, ActionListener listener) { - String responseContent = "deleted_job"; - listener.onResponse(new CancelAsyncQueryActionResponse(responseContent)); + try { + String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); + listener.onResponse( + new CancelAsyncQueryActionResponse( + String.format("Deleted async query with id: %s", jobId))); + } catch (Exception e) { + listener.onFailure(e); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java index e12f184efe..0065b575ed 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java @@ -9,11 +9,13 @@ import java.io.IOException; import lombok.AllArgsConstructor; +import lombok.Getter; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import 
org.opensearch.core.common.io.stream.StreamInput; @AllArgsConstructor +@Getter public class CancelAsyncQueryActionRequest extends ActionRequest { private String queryId; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index cf04278892..5e832777fc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -142,4 +142,34 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() { + " to enable Async Query APIs", illegalArgumentException.getMessage()); } + + @Test + void testCancelJobWithJobNotFound() { + AsyncQueryExecutorService asyncQueryExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.empty()); + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)); + Assertions.assertEquals( + "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); + verifyNoInteractions(sparkQueryDispatcher); + verifyNoInteractions(settings); + } + + @Test + void testCancelJob() { + AsyncQueryExecutorService asyncQueryExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + when(sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID)).thenReturn(EMR_JOB_ID); + String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); + verifyNoInteractions(settings); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 36f10cd08b..925ee73bcd 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -5,17 +5,22 @@ package org.opensearch.sql.spark.client; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_JOB_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.QUERY; import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import com.amazonaws.services.emrserverless.model.ValidationException; 
+import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -45,4 +50,28 @@ void testGetJobRunState() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); } + + @Test + void testCancelJobRun() { + when(emrServerless.cancelJobRun(any())) + .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); + } + + @Test + void testCancelJobRunWithValidationException() { + doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)); + Assertions.assertEquals( + "Couldn't cancel the queryId: job-123xxx due to Error (Service: null; Status Code: 0; Error" + + " Code: null; Request ID: null; Proxy: null)", + illegalArgumentException.getMessage()); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 800bd59b72..2000eeefed 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -14,6 +14,7 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -79,36 +80,17 @@ void testDispatchWithWrongURI() { illegalArgumentException.getMessage()); } - private DataSourceMetadata constructMyGlueDataSourceMetadata() { - DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); - dataSourceMetadata.setName("my_glue"); - dataSourceMetadata.setConnector(DataSourceType.S3GLUE); - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put( - "glue.indexstore.opensearch.uri", - "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); - properties.put("glue.indexstore.opensearch.auth", "sigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - dataSourceMetadata.setProperties(properties); - return dataSourceMetadata; - } - - private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { - DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); - dataSourceMetadata.setName("my_glue"); - dataSourceMetadata.setConnector(DataSourceType.S3GLUE); - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", 
"arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param"); - properties.put("glue.indexstore.opensearch.auth", "sigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - dataSourceMetadata.setProperties(properties); - return dataSourceMetadata; + @Test + void testCancelJob() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(sparkJobClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String jobId = sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); } @Test @@ -140,7 +122,7 @@ void testGetQueryResponseWithSuccess() { Assertions.assertEquals("SUCCESS", result.get("status")); } - String constructExpectedSparkSubmitParameterString() { + private String constructExpectedSparkSubmitParameterString() { return " --class org.opensearch.sql.FlintJob --conf" + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + " --conf" @@ -171,4 +153,36 @@ String constructExpectedSparkSubmitParameterString() { + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog "; } + + private DataSourceMetadata constructMyGlueDataSourceMetadata() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put( + "glue.indexstore.opensearch.uri", + "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } + + private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? 
param"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index c560c882c0..2ff76b9b57 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -7,6 +7,10 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; + import java.util.HashSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -19,6 +23,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -32,24 +37,40 @@ public class TransportCancelAsyncQueryRequestActionTest { @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private AsyncQueryExecutorServiceImpl asyncQueryExecutorService; + @Captor private ArgumentCaptor deleteJobActionResponseArgumentCaptor; + @Captor private ArgumentCaptor exceptionArgumentCaptor; + @BeforeEach public void setUp() { action = new TransportCancelAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>())); + transportService, new ActionFilters(new HashSet<>()), asyncQueryExecutorService); } @Test public void testDoExecute() { - CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest("jobId"); - + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse = deleteJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("deleted_job", cancelAsyncQueryActionResponse.getResult()); + Assertions.assertEquals( + "Deleted async query with id: " + EMR_JOB_ID, cancelAsyncQueryActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("Error", exception.getMessage()); } } From e3c4ee6c85182a659660bf8776c2a39bf9296cd2 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Wed, 27 Sep 2023 11:16:03 -0700 Subject: [PATCH 5/5] Add tags to the emr jobs based on 
the query types (#2150) Signed-off-by: Vamsi Manohar --- docs/user/interfaces/asyncqueryinterface.rst | 18 +- .../org/opensearch/sql/plugin/SQLPlugin.java | 17 +- spark/build.gradle | 36 +- .../src/main/antlr/FlintSparkSqlExtensions.g4 | 91 + spark/src/main/antlr/SparkSqlBase.g4 | 223 ++ spark/src/main/antlr/SqlBaseLexer.g4 | 551 +++++ spark/src/main/antlr/SqlBaseParser.g4 | 1905 +++++++++++++++++ .../AsyncQueryExecutorServiceImpl.java | 13 +- .../sql/spark/client/EMRServerlessClient.java | 45 + ...l.java => EmrServerlessClientImplEMR.java} | 25 +- .../sql/spark/client/SparkJobClient.java | 25 - .../sql/spark/client/StartJobRequest.java | 23 + .../dispatcher/SparkQueryDispatcher.java | 153 +- .../model/DispatchQueryRequest.java | 18 + .../model/FullyQualifiedTableName.java | 43 + .../spark/dispatcher/model/IndexDetails.java | 15 + .../rest/model/CreateAsyncQueryRequest.java | 11 +- .../sql/spark/rest/model/LangType.java | 36 + .../sql/spark/utils/SQLQueryUtils.java | 136 ++ .../AsyncQueryExecutorServiceImplTest.java | 27 +- .../client/EmrServerlessClientImplTest.java | 17 +- .../sql/spark/constants/TestConstants.java | 1 + .../dispatcher/SparkQueryDispatcherTest.java | 302 ++- ...portCreateAsyncQueryRequestActionTest.java | 5 +- .../sql/spark/utils/SQLQueryUtilsTest.java | 110 + 25 files changed, 3713 insertions(+), 133 deletions(-) create mode 100644 spark/src/main/antlr/FlintSparkSqlExtensions.g4 create mode 100644 spark/src/main/antlr/SparkSqlBase.g4 create mode 100644 spark/src/main/antlr/SqlBaseLexer.g4 create mode 100644 spark/src/main/antlr/SqlBaseParser.g4 create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClient.java rename spark/src/main/java/org/opensearch/sql/spark/client/{EmrServerlessClientImpl.java => EmrServerlessClientImplEMR.java} (82%) delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index f59afe8180..89529c8c82 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -36,15 +36,16 @@ Async Query Creation API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``. 
-HTTP URI: _plugins/_async_query -HTTP VERB: POST +HTTP URI: ``_plugins/_async_query`` + +HTTP VERB: ``POST`` Sample Request:: curl --location 'http://localhost:9200/_plugins/_async_query' \ --header 'Content-Type: application/json' \ --data '{ - "kind" : "sql", + "lang" : "sql", "query" : "select * from my_glue.default.http_logs limit 10" }' @@ -60,8 +61,9 @@ Async Query Result API If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/result``. Async Query Creation and Result Query permissions are orthogonal, so any user with result api permissions and queryId can query the corresponding query results irrespective of the user who created the async query. -HTTP URI: _plugins/_async_query/{queryId} -HTTP VERB: GET +HTTP URI: ``_plugins/_async_query/{queryId}`` + +HTTP VERB: ``GET`` Sample Request BODY:: @@ -75,6 +77,7 @@ Sample Response if the Query is in Progress :: Sample Response If the Query is successful :: { + "status": "SUCCESS", "schema": [ { "name": "indexed_col_name", @@ -105,8 +108,9 @@ Async Query Cancellation API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/jobs/delete``. -HTTP URI: _plugins/_async_query/{queryId} -HTTP VERB: DELETE +HTTP URI: ``_plugins/_async_query/{queryId}`` + +HTTP VERB: ``DELETE`` Sample Request Body :: diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index ed10b1e3e6..d5100885c4 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -93,8 +93,8 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.client.EmrServerlessClientImpl; -import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EmrServerlessClientImplEMR; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -297,20 +297,23 @@ private DataSourceServiceImpl createDataSourceService() { private AsyncQueryExecutorService createAsyncQueryExecutorService() { AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); - SparkJobClient sparkJobClient = createEMRServerlessClient(); + EMRServerlessClient EMRServerlessClient = createEMRServerlessClient(); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - sparkJobClient, this.dataSourceService, jobExecutionResponseReader); + EMRServerlessClient, + this.dataSourceService, + new DataSourceUserAuthorizationHelperImpl(client), + jobExecutionResponseReader); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, pluginSettings); } - private SparkJobClient createEMRServerlessClient() { + private EMRServerlessClient createEMRServerlessClient() { String sparkExecutionEngineConfigString = 
this.pluginSettings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); return AccessController.doPrivileged( - (PrivilegedAction) + (PrivilegedAction) () -> { SparkExecutionEngineConfig sparkExecutionEngineConfig = SparkExecutionEngineConfig.toSparkExecutionEngineConfig( @@ -320,7 +323,7 @@ private SparkJobClient createEMRServerlessClient() { .withRegion(sparkExecutionEngineConfig.getRegion()) .withCredentials(new DefaultAWSCredentialsProviderChain()) .build(); - return new EmrServerlessClientImpl(awsemrServerless); + return new EmrServerlessClientImplEMR(awsemrServerless); }); } } diff --git a/spark/build.gradle b/spark/build.gradle index fb9a1e0e4b..2bee7408a5 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -7,13 +7,42 @@ plugins { id 'java-library' id "io.freefair.lombok" id 'jacoco' + id 'antlr' } repositories { mavenCentral() } +tasks.register('downloadG4Files', Exec) { + description = 'Download remote .g4 files from GitHub' + + executable 'curl' + +// Need to add these back once the grammar issues with indexName and tableName is addressed in flint integration jar. +// args '-o', 'src/main/antlr/FlintSparkSqlExtensions.g4', 'https://raw.githubusercontent.com/opensearch-project/opensearch-spark/main/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4' +// args '-o', 'src/main/antlr/SparkSqlBase.g4', 'https://raw.githubusercontent.com/opensearch-project/opensearch-spark/main/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4' + args '-o', 'src/main/antlr/SqlBaseParser.g4', 'https://raw.githubusercontent.com/apache/spark/master/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4' + args '-o', 'src/main/antlr/SqlBaseLexer.g4', 'https://raw.githubusercontent.com/apache/spark/master/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4' +} + +generateGrammarSource { + arguments += ['-visitor', '-package', 'org.opensearch.sql.spark.antlr.parser'] + source = sourceSets.main.antlr + outputDirectory = file("build/generated-src/antlr/main/org/opensearch/sql/spark/antlr/parser") +} +configurations { + compile { + extendsFrom = extendsFrom.findAll { it != configurations.antlr } + } +} + +// Make sure the downloadG4File task runs before the generateGrammarSource task +generateGrammarSource.dependsOn downloadG4Files + dependencies { + antlr "org.antlr:antlr4:4.7.1" + api project(':core') implementation project(':protocol') implementation project(':datasources') @@ -46,7 +75,7 @@ jacocoTestReport { } afterEvaluate { classDirectories.setFrom(files(classDirectories.files.collect { - fileTree(dir: it) + fileTree(dir: it, exclude: ['**/antlr/parser/**']) })) } } @@ -61,7 +90,8 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.rest.*', 'org.opensearch.sql.spark.transport.model.*', 'org.opensearch.sql.spark.asyncquery.model.*', - 'org.opensearch.sql.spark.asyncquery.exceptions.*' + 'org.opensearch.sql.spark.asyncquery.exceptions.*', + 'org.opensearch.sql.spark.dispatcher.model.*' ] limit { counter = 'LINE' @@ -75,7 +105,7 @@ jacocoTestCoverageVerification { } afterEvaluate { classDirectories.setFrom(files(classDirectories.files.collect { - fileTree(dir: it) + fileTree(dir: it, exclude: ['**/antlr/parser/**']) })) } } diff --git a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 new file mode 100644 index 0000000000..2d50fbc49f --- /dev/null +++ b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +grammar FlintSparkSqlExtensions; + +import SparkSqlBase; + + +// Flint SQL Syntax Extension + +singleStatement + : statement SEMICOLON* EOF + ; + +statement + : skippingIndexStatement + | coveringIndexStatement + ; + +skippingIndexStatement + : createSkippingIndexStatement + | refreshSkippingIndexStatement + | describeSkippingIndexStatement + | dropSkippingIndexStatement + ; + +createSkippingIndexStatement + : CREATE SKIPPING INDEX ON tableName + LEFT_PAREN indexColTypeList RIGHT_PAREN + (WITH LEFT_PAREN propertyList RIGHT_PAREN)? + ; + +refreshSkippingIndexStatement + : REFRESH SKIPPING INDEX ON tableName + ; + +describeSkippingIndexStatement + : (DESC | DESCRIBE) SKIPPING INDEX ON tableName + ; + +dropSkippingIndexStatement + : DROP SKIPPING INDEX ON tableName + ; + +coveringIndexStatement + : createCoveringIndexStatement + | refreshCoveringIndexStatement + | showCoveringIndexStatement + | describeCoveringIndexStatement + | dropCoveringIndexStatement + ; + +createCoveringIndexStatement + : CREATE INDEX indexName ON tableName + LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN + (WITH LEFT_PAREN propertyList RIGHT_PAREN)? + ; + +refreshCoveringIndexStatement + : REFRESH INDEX indexName ON tableName + ; + +showCoveringIndexStatement + : SHOW (INDEX | INDEXES) ON tableName + ; + +describeCoveringIndexStatement + : (DESC | DESCRIBE) INDEX indexName ON tableName + ; + +dropCoveringIndexStatement + : DROP INDEX indexName ON tableName + ; + +indexColTypeList + : indexColType (COMMA indexColType)* + ; + +indexColType + : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX) + ; + +indexName + : identifier + ; + +tableName + : multipartIdentifier + ; diff --git a/spark/src/main/antlr/SparkSqlBase.g4 b/spark/src/main/antlr/SparkSqlBase.g4 new file mode 100644 index 0000000000..928f63812c --- /dev/null +++ b/spark/src/main/antlr/SparkSqlBase.g4 @@ -0,0 +1,223 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar SparkSqlBase; + +// Copy from Spark 3.3.1 SqlBaseParser.g4 and SqlBaseLexer.g4 + +@members { + /** + * When true, parser should throw ParseExcetion for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." 
is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. + */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } +} + + +multipartIdentifierPropertyList + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* + ; + +multipartIdentifierProperty + : multipartIdentifier (options=propertyList)? + ; + +propertyList + : property (COMMA property)* + ; + +property + : key=propertyKey (EQ? value=propertyValue)? + ; + +propertyKey + : identifier (DOT identifier)* + | STRING + ; + +propertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +booleanValue + : TRUE | FALSE + ; + + +multipartIdentifier + : parts+=identifier (DOT parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +nonReserved + : DROP | SKIPPING | INDEX + ; + + +// Flint lexical tokens + +MIN_MAX: 'MIN_MAX'; +SKIPPING: 'SKIPPING'; +VALUE_SET: 'VALUE_SET'; + + +// Spark lexical tokens + +SEMICOLON: ';'; + +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; + + +CREATE: 'CREATE'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DROP: 'DROP'; +FALSE: 'FALSE'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +ON: 'ON'; +PARTITION: 'PARTITION'; +REFRESH: 'REFRESH'; +SHOW: 'SHOW'; +STRING: 'STRING'; +TRUE: 'TRUE'; +WITH: 'WITH'; + + +EQ : '=' | '=='; +MINUS: '-'; + + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; \ No newline at end of file diff --git a/spark/src/main/antlr/SqlBaseLexer.g4 b/spark/src/main/antlr/SqlBaseLexer.g4 new file mode 100644 index 0000000000..d9128de0f5 --- /dev/null +++ b/spark/src/main/antlr/SqlBaseLexer.g4 @@ -0,0 +1,551 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +lexer grammar SqlBaseLexer; + +@members { + /** + * When true, parser should throw ParseException for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. + */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } +} + +SEMICOLON: ';'; + +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; +LEFT_BRACKET: '['; +RIGHT_BRACKET: ']'; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`, and +// modify `ParserUtils.toExprAlias()` which assumes all keywords are between `ADD` and `ZONE`. 
+ +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ALWAYS: 'ALWAYS'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ANY_VALUE: 'ANY_VALUE'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BIGINT: 'BIGINT'; +BINARY: 'BINARY'; +BOOLEAN: 'BOOLEAN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +BYTE: 'BYTE'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CATALOG: 'CATALOG'; +CATALOGS: 'CATALOGS'; +CHANGE: 'CHANGE'; +CHAR: 'CHAR'; +CHARACTER: 'CHARACTER'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DAY: 'DAY'; +DAYS: 'DAYS'; +DAYOFYEAR: 'DAYOFYEAR'; +DATA: 'DATA'; +DATE: 'DATE'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES'; +DATEADD: 'DATEADD'; +DATE_ADD: 'DATE_ADD'; +DATEDIFF: 'DATEDIFF'; +DATE_DIFF: 'DATE_DIFF'; +DBPROPERTIES: 'DBPROPERTIES'; +DEC: 'DEC'; +DECIMAL: 'DECIMAL'; +DECLARE: 'DECLARE'; +DEFAULT: 'DEFAULT'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DOUBLE: 'DOUBLE'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXCLUDE: 'EXCLUDE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FLOAT: 'FLOAT'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GENERATED: 'GENERATED'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +BINARY_HEX: 'X'; +HOUR: 'HOUR'; +HOURS: 'HOURS'; +IDENTIFIER_KW: 'IDENTIFIER'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INCLUDE: 'INCLUDE'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INT: 'INT'; +INTEGER: 'INTEGER'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +ILIKE: 'ILIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +LONG: 'LONG'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MICROSECOND: 'MICROSECOND'; +MICROSECONDS: 'MICROSECONDS'; +MILLISECOND: 'MILLISECOND'; +MILLISECONDS: 'MILLISECONDS'; +MINUTE: 'MINUTE'; +MINUTES: 'MINUTES'; 
+MONTH: 'MONTH'; +MONTHS: 'MONTHS'; +MSCK: 'MSCK'; +NAME: 'NAME'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NANOSECOND: 'NANOSECOND'; +NANOSECONDS: 'NANOSECONDS'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +NUMERIC: 'NUMERIC'; +OF: 'OF'; +OFFSET: 'OFFSET'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; +PARTITIONS: 'PARTITIONS'; +PERCENTILE_CONT: 'PERCENTILE_CONT'; +PERCENTILE_DISC: 'PERCENTILE_DISC'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUARTER: 'QUARTER'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +REAL: 'REAL'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPEATABLE: 'REPEATABLE'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESPECT: 'RESPECT'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SECOND: 'SECOND'; +SECONDS: 'SECONDS'; +SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHORT: 'SHORT'; +SHOW: 'SHOW'; +SINGLE: 'SINGLE'; +SKEWED: 'SKEWED'; +SMALLINT: 'SMALLINT'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +SOURCE: 'SOURCE'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRING: 'STRING'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; +SYSTEM_TIME: 'SYSTEM_TIME'; +SYSTEM_VERSION: 'SYSTEM_VERSION'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TARGET: 'TARGET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TIMEDIFF: 'TIMEDIFF'; +TIMESTAMP: 'TIMESTAMP'; +TIMESTAMP_LTZ: 'TIMESTAMP_LTZ'; +TIMESTAMP_NTZ: 'TIMESTAMP_NTZ'; +TIMESTAMPADD: 'TIMESTAMPADD'; +TIMESTAMPDIFF: 'TIMESTAMPDIFF'; +TINYINT: 'TINYINT'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNPIVOT: 'UNPIVOT'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VARCHAR: 'VARCHAR'; +VAR: 'VAR'; +VARIABLE: 'VARIABLE'; +VERSION: 'VERSION'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +VOID: 'VOID'; +WEEK: 'WEEK'; +WEEKS: 'WEEKS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +WITHIN: 'WITHIN'; +YEAR: 'YEAR'; +YEARS: 'YEARS'; +ZONE: 'ZONE'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + 
+PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; +COLON: ':'; +ARROW: '->'; +FAT_ARROW : '=>'; +HENT_START: '/*+'; +HENT_END: '*/'; +QUESTION: '?'; + +STRING_LITERAL + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | 'R\'' (~'\'')* '\'' + | 'R"'(~'"')* '"' + ; + +DOUBLEQUOTED_STRING + :'"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +// NOTE: If you move a numeric literal, you should modify `ParserUtils.toExprAlias()` +// which assumes all numeric literals are between `BIGINT_LITERAL` and `BIGDECIMAL_LITERAL`. + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 new file mode 100644 index 0000000000..6a6d39e96c --- /dev/null +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -0,0 +1,1905 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +parser grammar SqlBaseParser; + +options { tokenVocab = SqlBaseLexer; } + +@members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enabled = false; + + /** + * When false, a literal with an exponent would be converted into + * double type rather than decimal type. + */ + public boolean legacy_exponent_literal_as_decimal_enabled = false; + + /** + * When true, the behavior of keywords follows ANSI SQL standard. + */ + public boolean SQL_standard_keyword_behavior = false; + + /** + * When true, double quoted literals are identifiers rather than STRINGs. 
+ */ + public boolean double_quoted_identifiers = false; +} + +singleStatement + : statement SEMICOLON* EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleMultipartIdentifier + : multipartIdentifier EOF + ; + +singleFunctionIdentifier + : functionIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +singleTableSchema + : colTypeList EOF + ; + +statement + : query #statementDefault + | ctes? dmlStatementNoWith #dmlStatement + | USE identifierReference #use + | USE namespace identifierReference #useNamespace + | SET CATALOG (identifier | stringLit) #setCatalog + | CREATE namespace (IF NOT EXISTS)? identifierReference + (commentSpec | + locationSpec | + (WITH (DBPROPERTIES | PROPERTIES) propertyList))* #createNamespace + | ALTER namespace identifierReference + SET (DBPROPERTIES | PROPERTIES) propertyList #setNamespaceProperties + | ALTER namespace identifierReference + SET locationSpec #setNamespaceLocation + | DROP namespace (IF EXISTS)? identifierReference + (RESTRICT | CASCADE)? #dropNamespace + | SHOW namespaces ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=stringLit)? #showNamespaces + | createTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier + (tableProvider | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=propertyList))* #createTableLike + | replaceTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ANALYZE TABLES ((FROM | IN) identifierReference)? COMPUTE STATISTICS + (identifier)? #analyzeTables + | ALTER TABLE identifierReference + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE identifierReference + ADD (COLUMN | COLUMNS) + LEFT_PAREN columns=qualifiedColTypeWithPositionList RIGHT_PAREN #addTableColumns + | ALTER TABLE table=identifierReference + RENAME COLUMN + from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn + | ALTER TABLE identifierReference + DROP (COLUMN | COLUMNS) (IF EXISTS)? + LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN #dropTableColumns + | ALTER TABLE identifierReference + DROP (COLUMN | COLUMNS) (IF EXISTS)? + columns=multipartIdentifierList #dropTableColumns + | ALTER (TABLE | VIEW) from=identifierReference + RENAME TO to=multipartIdentifier #renameTable + | ALTER (TABLE | VIEW) identifierReference + SET TBLPROPERTIES propertyList #setTableProperties + | ALTER (TABLE | VIEW) identifierReference + UNSET TBLPROPERTIES (IF EXISTS)? propertyList #unsetTableProperties + | ALTER TABLE table=identifierReference + (ALTER | CHANGE) COLUMN? column=multipartIdentifier + alterColumnAction? #alterTableAlterColumn + | ALTER TABLE table=identifierReference partitionSpec? + CHANGE COLUMN? + colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=identifierReference partitionSpec? + REPLACE COLUMNS + LEFT_PAREN columns=qualifiedColTypeWithPositionList + RIGHT_PAREN #hiveReplaceColumns + | ALTER TABLE identifierReference (partitionSpec)? + SET SERDE stringLit (WITH SERDEPROPERTIES propertyList)? #setTableSerDe + | ALTER TABLE identifierReference (partitionSpec)? 
+ SET SERDEPROPERTIES propertyList #setTableSerDe + | ALTER (TABLE | VIEW) identifierReference ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE identifierReference + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER (TABLE | VIEW) identifierReference + DROP (IF EXISTS)? partitionSpec (COMMA partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE identifierReference + (partitionSpec)? SET locationSpec #setTableLocation + | ALTER TABLE identifierReference RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable + | DROP VIEW (IF EXISTS)? identifierReference #dropView + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? identifierReference + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES propertyList))* + AS query #createView + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW + tableIdentifier (LEFT_PAREN colTypeList RIGHT_PAREN)? tableProvider + (OPTIONS propertyList)? #createTempViewUsing + | ALTER VIEW identifierReference AS? query #alterViewQuery + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + identifierReference AS className=stringLit + (USING resource (COMMA resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction + | DECLARE (OR REPLACE)? VARIABLE? + identifierReference dataType? variableDefaultExpression? #createVariable + | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain + | SHOW TABLES ((FROM | IN) identifierReference)? + (LIKE? pattern=stringLit)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) ns=identifierReference)? + LIKE pattern=stringLit partitionSpec? #showTableExtended + | SHOW TBLPROPERTIES table=identifierReference + (LEFT_PAREN key=propertyKey RIGHT_PAREN)? #showTblProperties + | SHOW COLUMNS (FROM | IN) table=identifierReference + ((FROM | IN) ns=multipartIdentifier)? #showColumns + | SHOW VIEWS ((FROM | IN) identifierReference)? + (LIKE? pattern=stringLit)? #showViews + | SHOW PARTITIONS identifierReference partitionSpec? #showPartitions + | SHOW identifier? FUNCTIONS ((FROM | IN) ns=identifierReference)? + (LIKE? (legacy=multipartIdentifier | pattern=stringLit))? #showFunctions + | SHOW CREATE TABLE identifierReference (AS SERDE)? #showCreateTable + | SHOW CURRENT namespace #showCurrentNamespace + | SHOW CATALOGS (LIKE? pattern=stringLit)? #showCatalogs + | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) namespace EXTENDED? + identifierReference #describeNamespace + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? + identifierReference partitionSpec? describeColName? #describeRelation + | (DESC | DESCRIBE) QUERY? query #describeQuery + | COMMENT ON namespace identifierReference IS + comment #commentNamespace + | COMMENT ON TABLE identifierReference IS comment #commentTable + | REFRESH TABLE identifierReference #refreshTable + | REFRESH FUNCTION identifierReference #refreshFunction + | REFRESH (stringLit | .*?) #refreshResource + | CACHE LAZY? TABLE identifierReference + (OPTIONS options=propertyList)? (AS? query)? #cacheTable + | UNCACHE TABLE (IF EXISTS)? identifierReference #uncacheTable + | CLEAR CACHE #clearCache + | LOAD DATA LOCAL? INPATH path=stringLit OVERWRITE? INTO TABLE + identifierReference partitionSpec? 
#loadData + | TRUNCATE TABLE identifierReference partitionSpec? #truncateTable + | (MSCK)? REPAIR TABLE identifierReference + (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable + | op=(ADD | LIST) identifier .*? #manageResource + | SET ROLE .*? #failNativeCommand + | SET TIME ZONE interval #setTimeZone + | SET TIME ZONE timezone #setTimeZone + | SET TIME ZONE .*? #setTimeZone + | SET (VARIABLE | VAR) assignmentList #setVariable + | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + LEFT_PAREN query RIGHT_PAREN #setVariable + | SET configKey EQ configValue #setQuotedConfiguration + | SET configKey (EQ .*?)? #setConfiguration + | SET .*? EQ configValue #setQuotedConfiguration + | SET .*? #setConfiguration + | RESET configKey #resetQuotedConfiguration + | RESET .*? #resetConfiguration + | CREATE INDEX (IF NOT EXISTS)? identifier ON TABLE? + identifierReference (USING indexType=identifier)? + LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN + (OPTIONS options=propertyList)? #createIndex + | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +timezone + : stringLit + | LOCAL + ; + +configKey + : quotedIdentifier + ; + +configValue + : backQuotedIdentifier + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS + | kw1=START kw2=TRANSACTION + | kw1=COMMIT + | kw1=ROLLBACK + | kw1=DFS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? identifierReference + ; + +replaceTableHeader + : (CREATE OR)? REPLACE TABLE identifierReference + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? 
+ ; + +locationSpec + : LOCATION stringLit + ; + +commentSpec + : COMMENT stringLit + ; + +query + : ctes? queryTerm queryOrganization + ; + +insertInto + : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF NOT EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable + | INSERT INTO TABLE? identifierReference partitionSpec? (IF NOT EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable + | INSERT INTO TABLE? identifierReference REPLACE whereClause #insertIntoReplaceWhere + | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat? #insertOverwriteHiveDir + | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? #insertOverwriteDir + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION LEFT_PAREN partitionVal (COMMA partitionVal)* RIGHT_PAREN + ; + +partitionVal + : identifier (EQ constant)? + | identifier EQ DEFAULT + ; + +namespace + : NAMESPACE + | DATABASE + | SCHEMA + ; + +namespaces + : NAMESPACES + | DATABASES + | SCHEMAS + ; + +describeFuncName + : identifierReference + | stringLit + | comparisonOperator + | arithmeticOperator + | predicateOperator + ; + +describeColName + : nameParts+=identifier (DOT nameParts+=identifier)* + ; + +ctes + : WITH namedQuery (COMMA namedQuery)* + ; + +namedQuery + : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? LEFT_PAREN query RIGHT_PAREN + ; + +tableProvider + : USING multipartIdentifier + ; + +createTableClauses + :((OPTIONS options=expressionPropertyList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | + bucketSpec | + rowFormat | + createFileFormat | + locationSpec | + commentSpec | + (TBLPROPERTIES tableProps=propertyList))* + ; + +propertyList + : LEFT_PAREN property (COMMA property)* RIGHT_PAREN + ; + +property + : key=propertyKey (EQ? value=propertyValue)? + ; + +propertyKey + : identifier (DOT identifier)* + | stringLit + ; + +propertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | stringLit + ; + +expressionPropertyList + : LEFT_PAREN expressionProperty (COMMA expressionProperty)* RIGHT_PAREN + ; + +expressionProperty + : key=propertyKey (EQ? value=expression)? + ; + +constantList + : LEFT_PAREN constant (COMMA constant)* RIGHT_PAREN + ; + +nestedConstantList + : LEFT_PAREN constantList (COMMA constantList)* RIGHT_PAREN + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=stringLit OUTPUTFORMAT outFmt=stringLit #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : stringLit (WITH SERDEPROPERTIES propertyList)? + ; + +resource + : identifier stringLit + ; + +dmlStatementNoWith + : insertInto query #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + | DELETE FROM identifierReference tableAlias whereClause? #deleteFromTable + | UPDATE identifierReference tableAlias setClause whereClause? #updateTable + | MERGE INTO target=identifierReference targetAlias=tableAlias + USING (source=identifierReference | + LEFT_PAREN sourceQuery=query RIGHT_PAREN) sourceAlias=tableAlias + ON mergeCondition=booleanExpression + matchedClause* + notMatchedClause* + notMatchedBySourceClause* #mergeIntoTable + ; + +identifierReference + : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | multipartIdentifier + ; + +queryOrganization + : (ORDER BY order+=sortItem (COMMA order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (COMMA clusterBy+=expression)*)? 
+ (DISTRIBUTE BY distributeBy+=expression (COMMA distributeBy+=expression)*)? + (SORT BY sort+=sortItem (COMMA sort+=sortItem)*)? + windowClause? + (LIMIT (ALL | limit=expression))? + (OFFSET offset=expression)? + ; + +multiInsertQueryBody + : insertInto fromStatementBody + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enabled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | fromStatement #fromStmt + | TABLE identifierReference #table + | inlineTable #inlineTableDefault1 + | LEFT_PAREN query RIGHT_PAREN #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? + ; + +fromStatement + : fromClause fromStatementBody+ + ; + +fromStatementBody + : transformClause + whereClause? + queryOrganization + | selectClause + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? + queryOrganization + ; + +querySpecification + : transformClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #transformQuerySpecification + | selectClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #regularQuerySpecification + ; + +transformClause + : (SELECT kind=TRANSFORM LEFT_PAREN setQuantifier? expressionSeq RIGHT_PAREN + | kind=MAP setQuantifier? expressionSeq + | kind=REDUCE setQuantifier? expressionSeq) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=stringLit)? + USING script=stringLit + (AS (identifierSeq | colTypeList | (LEFT_PAREN (identifierSeq | colTypeList) RIGHT_PAREN)))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=stringLit)? + ; + +selectClause + : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq + ; + +setClause + : SET assignmentList + ; + +matchedClause + : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction + ; +notMatchedClause + : WHEN NOT MATCHED (BY TARGET)? (AND notMatchedCond=booleanExpression)? THEN notMatchedAction + ; + +notMatchedBySourceClause + : WHEN NOT MATCHED BY SOURCE (AND notMatchedBySourceCond=booleanExpression)? THEN notMatchedBySourceAction + ; + +matchedAction + : DELETE + | UPDATE SET ASTERISK + | UPDATE SET assignmentList + ; + +notMatchedAction + : INSERT ASTERISK + | INSERT LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN + VALUES LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN + ; + +notMatchedBySourceAction + : DELETE + | UPDATE SET assignmentList + ; + +assignmentList + : assignment (COMMA assignment)* + ; + +assignment + : key=multipartIdentifier EQ value=expression + ; + +whereClause + : WHERE booleanExpression + ; + +havingClause + : HAVING booleanExpression + ; + +hint + : HENT_START hintStatements+=hintStatement (COMMA? hintStatements+=hintStatement)* HENT_END + ; + +hintStatement + : hintName=identifier + | hintName=identifier LEFT_PAREN parameters+=primaryExpression (COMMA parameters+=primaryExpression)* RIGHT_PAREN + ; + +fromClause + : FROM relation (COMMA relation)* lateralView* pivotClause? unpivotClause? + ; + +temporalClause + : FOR? (SYSTEM_VERSION | VERSION) AS OF version + | FOR? 
(SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + ; + +aggregationClause + : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause + (COMMA groupingExpressionsWithGroupingAnalytics+=groupByClause)* + | GROUP BY groupingExpressions+=expression (COMMA groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN)? + ; + +groupByClause + : groupingAnalytics + | expression + ; + +groupingAnalytics + : (ROLLUP | CUBE) LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN + | GROUPING SETS LEFT_PAREN groupingElement (COMMA groupingElement)* RIGHT_PAREN + ; + +groupingElement + : groupingAnalytics + | groupingSet + ; + +groupingSet + : LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN + | expression + ; + +pivotClause + : PIVOT LEFT_PAREN aggregates=namedExpressionSeq FOR pivotColumn IN LEFT_PAREN pivotValues+=pivotValue (COMMA pivotValues+=pivotValue)* RIGHT_PAREN RIGHT_PAREN + ; + +pivotColumn + : identifiers+=identifier + | LEFT_PAREN identifiers+=identifier (COMMA identifiers+=identifier)* RIGHT_PAREN + ; + +pivotValue + : expression (AS? identifier)? + ; + +unpivotClause + : UNPIVOT nullOperator=unpivotNullClause? LEFT_PAREN + operator=unpivotOperator + RIGHT_PAREN (AS? identifier)? + ; + +unpivotNullClause + : (INCLUDE | EXCLUDE) NULLS + ; + +unpivotOperator + : (unpivotSingleValueColumnClause | unpivotMultiValueColumnClause) + ; + +unpivotSingleValueColumnClause + : unpivotValueColumn FOR unpivotNameColumn IN LEFT_PAREN unpivotColumns+=unpivotColumnAndAlias (COMMA unpivotColumns+=unpivotColumnAndAlias)* RIGHT_PAREN + ; + +unpivotMultiValueColumnClause + : LEFT_PAREN unpivotValueColumns+=unpivotValueColumn (COMMA unpivotValueColumns+=unpivotValueColumn)* RIGHT_PAREN + FOR unpivotNameColumn + IN LEFT_PAREN unpivotColumnSets+=unpivotColumnSet (COMMA unpivotColumnSets+=unpivotColumnSet)* RIGHT_PAREN + ; + +unpivotColumnSet + : LEFT_PAREN unpivotColumns+=unpivotColumn (COMMA unpivotColumns+=unpivotColumn)* RIGHT_PAREN unpivotAlias? + ; + +unpivotValueColumn + : identifier + ; + +unpivotNameColumn + : identifier + ; + +unpivotColumnAndAlias + : unpivotColumn unpivotAlias? + ; + +unpivotColumn + : multipartIdentifier + ; + +unpivotAlias + : AS? identifier + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN tblName=identifier (AS? colName+=identifier (COMMA colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : LATERAL? relationPrimary relationExtension* + ; + +relationExtension + : joinRelation + | pivotClause + | unpivotClause + ; + +joinRelation + : (joinType) JOIN LATERAL? right=relationPrimary joinCriteria? + | NATURAL joinType JOIN LATERAL? right=relationPrimary + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | LEFT? SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING identifierList + ; + +sample + : TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=INTEGER_VALUE RIGHT_PAREN)? + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName LEFT_PAREN RIGHT_PAREN))? 
#sampleByBucket + | bytes=expression #sampleByBytes + ; + +identifierList + : LEFT_PAREN identifierSeq RIGHT_PAREN + ; + +identifierSeq + : ident+=errorCapturingIdentifier (COMMA ident+=errorCapturingIdentifier)* + ; + +orderedIdentifierList + : LEFT_PAREN orderedIdentifier (COMMA orderedIdentifier)* RIGHT_PAREN + ; + +orderedIdentifier + : ident=errorCapturingIdentifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : LEFT_PAREN identifierComment (COMMA identifierComment)* RIGHT_PAREN + ; + +identifierComment + : identifier commentSpec? + ; + +relationPrimary + : identifierReference temporalClause? + sample? tableAlias #tableName + | LEFT_PAREN query RIGHT_PAREN sample? tableAlias #aliasedQuery + | LEFT_PAREN relation RIGHT_PAREN sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction + ; + +inlineTable + : VALUES expression (COMMA expression)* tableAlias + ; + +functionTableSubqueryArgument + : TABLE identifierReference tableArgumentPartitioning? + | TABLE LEFT_PAREN identifierReference RIGHT_PAREN tableArgumentPartitioning? + | TABLE LEFT_PAREN query RIGHT_PAREN tableArgumentPartitioning? + ; + +tableArgumentPartitioning + : ((WITH SINGLE PARTITION) + | ((PARTITION | DISTRIBUTE) BY + (((LEFT_PAREN partition+=expression (COMMA partition+=expression)* RIGHT_PAREN)) + | partition+=expression))) + ((ORDER | SORT) BY + (((LEFT_PAREN sortItem (COMMA sortItem)* RIGHT_PAREN) + | sortItem)))? + ; + +functionTableNamedArgumentExpression + : key=identifier FAT_ARROW table=functionTableSubqueryArgument + ; + +functionTableReferenceArgument + : functionTableSubqueryArgument + | functionTableNamedArgumentExpression + ; + +functionTableArgument + : functionTableReferenceArgument + | functionArgument + ; + +functionTable + : funcName=functionName LEFT_PAREN + (functionTableArgument (COMMA functionTableArgument)*)? + RIGHT_PAREN tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=stringLit (WITH SERDEPROPERTIES props=propertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=stringLit (ESCAPED BY escapedBy=stringLit)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=stringLit)? + (MAP KEYS TERMINATED BY keysTerminatedBy=stringLit)? + (LINES TERMINATED BY linesSeparatedBy=stringLit)? + (NULL DEFINED AS nullDefinedAs=stringLit)? #rowFormatDelimited + ; + +multipartIdentifierList + : multipartIdentifier (COMMA multipartIdentifier)* + ; + +multipartIdentifier + : parts+=errorCapturingIdentifier (DOT parts+=errorCapturingIdentifier)* + ; + +multipartIdentifierPropertyList + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* + ; + +multipartIdentifierProperty + : multipartIdentifier (OPTIONS options=propertyList)? + ; + +tableIdentifier + : (db=errorCapturingIdentifier DOT)? table=errorCapturingIdentifier + ; + +functionIdentifier + : (db=errorCapturingIdentifier DOT)? function=errorCapturingIdentifier + ; + +namedExpression + : expression (AS? (name=errorCapturingIdentifier | identifierList))? 
+ ; + +namedExpressionSeq + : namedExpression (COMMA namedExpression)* + ; + +partitionFieldList + : LEFT_PAREN fields+=partitionField (COMMA fields+=partitionField)* RIGHT_PAREN + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + LEFT_PAREN argument+=transformArgument (COMMA argument+=transformArgument)* RIGHT_PAREN #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + +expression + : booleanExpression + ; + +namedArgumentExpression + : key=identifier FAT_ARROW value=expression + ; + +functionArgument + : expression + | namedArgumentExpression + ; + +expressionSeq + : expression (COMMA expression)* + ; + +booleanExpression + : NOT booleanExpression #logicalNot + | EXISTS LEFT_PAREN query RIGHT_PAREN #exists + | valueExpression predicate? #predicated + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + ; + +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN + | NOT? kind=IN LEFT_PAREN query RIGHT_PAREN + | NOT? kind=RLIKE pattern=valueExpression + | NOT? kind=(LIKE | ILIKE) quantifier=(ANY | SOME | ALL) (LEFT_PAREN RIGHT_PAREN | LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN) + | NOT? kind=(LIKE | ILIKE) pattern=valueExpression (ESCAPE escapeChar=stringLit)? + | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) + | IS NOT? kind=DISTINCT FROM right=valueExpression + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +datetimeUnit + : YEAR | QUARTER | MONTH + | WEEK | DAY | DAYOFYEAR + | HOUR | MINUTE | SECOND | MILLISECOND | MICROSECOND + ; + +primaryExpression + : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER | USER | SESSION_USER) #currentLike + | name=(TIMESTAMPADD | DATEADD | DATE_ADD) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA unitsAmount=valueExpression COMMA timestamp=valueExpression RIGHT_PAREN #timestampadd + | name=(TIMESTAMPDIFF | DATEDIFF | DATE_DIFF | TIMEDIFF) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA startTimestamp=valueExpression COMMA endTimestamp=valueExpression RIGHT_PAREN #timestampdiff + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast + | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct + | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first + | ANY_VALUE LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #any_value + | LAST LEFT_PAREN expression (IGNORE NULLS)? 
RIGHT_PAREN #last + | POSITION LEFT_PAREN substr=valueExpression IN str=valueExpression RIGHT_PAREN #position + | constant #constantDefault + | ASTERISK #star + | qualifiedName DOT ASTERISK #star + | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor + | LEFT_PAREN query RIGHT_PAREN #subqueryExpression + | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause + | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument + (COMMA argument+=functionArgument)*)? RIGHT_PAREN + (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? + (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall + | identifier ARROW expression #lambda + | LEFT_PAREN identifier (COMMA identifier)+ RIGHT_PAREN ARROW expression #lambda + | value=primaryExpression LEFT_BRACKET index=valueExpression RIGHT_BRACKET #subscript + | identifier #columnReference + | base=primaryExpression DOT fieldName=identifier #dereference + | LEFT_PAREN expression RIGHT_PAREN #parenthesizedExpression + | EXTRACT LEFT_PAREN field=identifier FROM source=valueExpression RIGHT_PAREN #extract + | (SUBSTR | SUBSTRING) LEFT_PAREN str=valueExpression (FROM | COMMA) pos=valueExpression + ((FOR | COMMA) len=valueExpression)? RIGHT_PAREN #substring + | TRIM LEFT_PAREN trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression RIGHT_PAREN #trim + | OVERLAY LEFT_PAREN input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN #overlay + | name=(PERCENTILE_CONT | PERCENTILE_DISC) LEFT_PAREN percentage=valueExpression RIGHT_PAREN + WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN + (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)? #percentile + ; + +literalType + : DATE + | TIMESTAMP | TIMESTAMP_LTZ | TIMESTAMP_NTZ + | INTERVAL + | BINARY_HEX + | unsupportedType=identifier + ; + +constant + : NULL #nullLiteral + | QUESTION #posParameterLiteral + | COLON identifier #namedParameterLiteral + | interval #intervalLiteral + | literalType stringLit #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | stringLit+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +arithmeticOperator + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT + ; + +predicateOperator + : OR | AND | IN | NOT + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval) + ; + +errorCapturingMultiUnitsInterval + : body=multiUnitsInterval unitToUnitInterval? + ; + +multiUnitsInterval + : (intervalValue unit+=unitInMultiUnits)+ + ; + +errorCapturingUnitToUnitInterval + : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)? + ; + +unitToUnitInterval + : value=intervalValue from=unitInUnitToUnit TO to=unitInUnitToUnit + ; + +intervalValue + : (PLUS | MINUS)? 
+ (INTEGER_VALUE | DECIMAL_VALUE | stringLit) + ; + +unitInMultiUnits + : NANOSECOND | NANOSECONDS | MICROSECOND | MICROSECONDS | MILLISECOND | MILLISECONDS + | SECOND | SECONDS | MINUTE | MINUTES | HOUR | HOURS | DAY | DAYS | WEEK | WEEKS + | MONTH | MONTHS | YEAR | YEARS + ; + +unitInUnitToUnit + : SECOND | MINUTE | HOUR | DAY | MONTH | YEAR + ; + +colPosition + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier + ; + +type + : BOOLEAN + | TINYINT | BYTE + | SMALLINT | SHORT + | INT | INTEGER + | BIGINT | LONG + | FLOAT | REAL + | DOUBLE + | DATE + | TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ + | STRING + | CHARACTER | CHAR + | VARCHAR + | BINARY + | DECIMAL | DEC | NUMERIC + | VOID + | INTERVAL + | ARRAY | STRUCT | MAP + | unsupportedType=identifier + ; + +dataType + : complex=ARRAY LT dataType GT #complexDataType + | complex=MAP LT dataType COMMA dataType GT #complexDataType + | complex=STRUCT (LT complexColTypeList? GT | NEQ) #complexDataType + | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType + | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) + (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType + | type (LEFT_PAREN INTEGER_VALUE + (COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType + ; + +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (COMMA qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=multipartIdentifier dataType colDefinitionDescriptorWithPosition* + ; + +colDefinitionDescriptorWithPosition + : NOT NULL + | defaultExpression + | commentSpec + | colPosition + ; + +defaultExpression + : DEFAULT expression + ; + +variableDefaultExpression + : (DEFAULT | EQ) expression + ; + +colTypeList + : colType (COMMA colType)* + ; + +colType + : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? + ; + +createOrReplaceTableColTypeList + : createOrReplaceTableColType (COMMA createOrReplaceTableColType)* + ; + +createOrReplaceTableColType + : colName=errorCapturingIdentifier dataType colDefinitionOption* + ; + +colDefinitionOption + : NOT NULL + | defaultExpression + | generationExpression + | commentSpec + ; + +generationExpression + : GENERATED ALWAYS AS LEFT_PAREN expression RIGHT_PAREN + ; + +complexColTypeList + : complexColType (COMMA complexColType)* + ; + +complexColType + : identifier COLON? dataType (NOT NULL)? commentSpec? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windowClause + : WINDOW namedWindow (COMMA namedWindow)* + ; + +namedWindow + : name=errorCapturingIdentifier AS windowSpec + ; + +windowSpec + : name=errorCapturingIdentifier #windowRef + | LEFT_PAREN name=errorCapturingIdentifier RIGHT_PAREN #windowRef + | LEFT_PAREN + ( CLUSTER BY partition+=expression (COMMA partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (COMMA partition+=expression)*)? + ((ORDER | SORT) BY sortItem (COMMA sortItem)*)?) + windowFrame? 
+ RIGHT_PAREN #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + +qualifiedNameList + : qualifiedName (COMMA qualifiedName)* + ; + +functionName + : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | qualifiedName + | FILTER + | LEFT + | RIGHT + ; + +qualifiedName + : identifier (DOT identifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : strictIdentifier + | {!SQL_standard_keyword_behavior}? strictNonReserved + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier + | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + | {double_quoted_identifiers}? DOUBLEQUOTED_STRING + ; + +backQuotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral + | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral + | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + | SET defaultExpression + | dropDefault=DROP DEFAULT + ; + +stringLit + : STRING_LITERAL + | {!double_quoted_identifiers}? DOUBLEQUOTED_STRING + ; + +comment + : stringLit + | NULL + ; + +version + : INTEGER_VALUE + | stringLit + ; + +// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. +// - Reserved keywords: +// Keywords that are reserved and can't be used as identifiers for table, view, column, +// function, alias, etc. +// - Non-reserved keywords: +// Keywords that have a special meaning only in particular contexts and can be used as +// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN +// can be used as identifiers in other places. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords. 
+ansiNonReserved +//--ANSI-NON-RESERVED-START + : ADD + | AFTER + | ALTER + | ALWAYS + | ANALYZE + | ANTI + | ANY_VALUE + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BIGINT + | BINARY + | BINARY_HEX + | BOOLEAN + | BUCKET + | BUCKETS + | BY + | BYTE + | CACHE + | CASCADE + | CATALOG + | CATALOGS + | CHANGE + | CHAR + | CHARACTER + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DATE + | DATEADD + | DATE_ADD + | DATEDIFF + | DATE_DIFF + | DAY + | DAYS + | DAYOFYEAR + | DBPROPERTIES + | DEC + | DECIMAL + | DECLARE + | DEFAULT + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DOUBLE + | DROP + | ESCAPED + | EXCHANGE + | EXCLUDE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FLOAT + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GENERATED + | GLOBAL + | GROUPING + | HOUR + | HOURS + | IDENTIFIER_KW + | IF + | IGNORE + | IMPORT + | INCLUDE + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INT + | INTEGER + | INTERVAL + | ITEMS + | KEYS + | LAST + | LAZY + | LIKE + | ILIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | LONG + | MACRO + | MAP + | MATCHED + | MERGE + | MICROSECOND + | MICROSECONDS + | MILLISECOND + | MILLISECONDS + | MINUTE + | MINUTES + | MONTH + | MONTHS + | MSCK + | NAME + | NAMESPACE + | NAMESPACES + | NANOSECOND + | NANOSECONDS + | NO + | NULLS + | NUMERIC + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRINCIPALS + | PROPERTIES + | PURGE + | QUARTER + | QUERY + | RANGE + | REAL + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPEATABLE + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SCHEMAS + | SECOND + | SECONDS + | SEMI + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETMINUS + | SETS + | SHORT + | SHOW + | SINGLE + | SKEWED + | SMALLINT + | SORT + | SORTED + | SOURCE + | START + | STATISTICS + | STORED + | STRATIFY + | STRING + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | SYSTEM_TIME + | SYSTEM_VERSION + | TABLES + | TABLESAMPLE + | TARGET + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TIMEDIFF + | TIMESTAMP + | TIMESTAMP_LTZ + | TIMESTAMP_NTZ + | TIMESTAMPADD + | TIMESTAMPDIFF + | TINYINT + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNPIVOT + | UNSET + | UPDATE + | USE + | VALUES + | VARCHAR + | VAR + | VARIABLE + | VERSION + | VIEW + | VIEWS + | VOID + | WEEK + | WEEKS + | WINDOW + | YEAR + | YEARS + | ZONE +//--ANSI-NON-RESERVED-END + ; + +// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL. +// - Non-reserved keywords: +// Same definition as the one when `SQL_standard_keyword_behavior=true`. +// - Strict-non-reserved keywords: +// A strict version of non-reserved keywords, which can not be used as table alias. +// You can find the full keywords list by searching "Start of the keywords list" in this file. 
+// The strict-non-reserved keywords are listed in `strictNonReserved`. +// The non-reserved keywords are listed in `nonReserved`. +// These 2 together contain all the keywords. +strictNonReserved + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LATERAL + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION + | USING + ; + +nonReserved +//--DEFAULT-NON-RESERVED-START + : ADD + | AFTER + | ALL + | ALTER + | ALWAYS + | ANALYZE + | AND + | ANY + | ANY_VALUE + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BIGINT + | BINARY + | BINARY_HEX + | BOOLEAN + | BOTH + | BUCKET + | BUCKETS + | BY + | BYTE + | CACHE + | CASCADE + | CASE + | CAST + | CATALOG + | CATALOGS + | CHANGE + | CHAR + | CHARACTER + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DATE + | DATEADD + | DATE_ADD + | DATEDIFF + | DATE_DIFF + | DAY + | DAYS + | DAYOFYEAR + | DBPROPERTIES + | DEC + | DECIMAL + | DECLARE + | DEFAULT + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DOUBLE + | DROP + | ELSE + | END + | ESCAPE + | ESCAPED + | EXCHANGE + | EXCLUDE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FILTER + | FIELDS + | FILEFORMAT + | FIRST + | FLOAT + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GENERATED + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | HOUR + | HOURS + | IDENTIFIER_KW + | IF + | IGNORE + | IMPORT + | IN + | INCLUDE + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INT + | INTEGER + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LAZY + | LEADING + | LIKE + | LONG + | ILIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | LONG + | MACRO + | MAP + | MATCHED + | MERGE + | MICROSECOND + | MICROSECONDS + | MILLISECOND + | MILLISECONDS + | MINUTE + | MINUTES + | MONTH + | MONTHS + | MSCK + | NAME + | NAMESPACE + | NAMESPACES + | NANOSECOND + | NANOSECONDS + | NO + | NOT + | NULL + | NULLS + | NUMERIC + | OF + | OFFSET + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTILE_CONT + | PERCENTILE_DISC + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PROPERTIES + | PURGE + | QUARTER + | QUERY + | RANGE + | REAL + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPEATABLE + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SCHEMAS + | SECOND + | SECONDS + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHORT + | SHOW + | SINGLE + | SKEWED + | SMALLINT + | SOME + | SORT + | SORTED + | SOURCE + | START + | STATISTICS + | STORED + | STRATIFY + | STRING + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | SYSTEM_TIME + | SYSTEM_VERSION + | TABLE + | TABLES + | TABLESAMPLE + | TARGET + | TBLPROPERTIES + | TEMPORARY + | 
TERMINATED + | THEN + | TIME + | TIMEDIFF + | TIMESTAMP + | TIMESTAMP_LTZ + | TIMESTAMP_NTZ + | TIMESTAMPADD + | TIMESTAMPDIFF + | TINYINT + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNKNOWN + | UNLOCK + | UNPIVOT + | UNSET + | UPDATE + | USE + | USER + | VALUES + | VARCHAR + | VAR + | VARIABLE + | VERSION + | VIEW + | VIEWS + | VOID + | WEEK + | WEEKS + | WHEN + | WHERE + | WINDOW + | WITH + | WITHIN + | YEAR + | YEARS + | ZONE +//--DEFAULT-NON-RESERVED-END + ; diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index efc23e08b5..a86aa82695 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.asyncquery; +import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -15,6 +16,7 @@ import java.util.Optional; import lombok.AllArgsConstructor; import org.json.JSONObject; +import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; @@ -22,6 +24,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -60,11 +63,15 @@ public CreateAsyncQueryResponse createAsyncQuery( () -> SparkExecutionEngineConfig.toSparkExecutionEngineConfig( sparkExecutionEngineConfigString)); + ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); String jobId = sparkQueryDispatcher.dispatch( - sparkExecutionEngineConfig.getApplicationId(), - createAsyncQueryRequest.getQuery(), - sparkExecutionEngineConfig.getExecutionRoleARN()); + new DispatchQueryRequest( + sparkExecutionEngineConfig.getApplicationId(), + createAsyncQueryRequest.getQuery(), + createAsyncQueryRequest.getLang(), + sparkExecutionEngineConfig.getExecutionRoleARN(), + clusterName.value())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); return new CreateAsyncQueryResponse(jobId); diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClient.java new file mode 100644 index 0000000000..8dff8f0ea6 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClient.java @@ -0,0 +1,45 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.sql.spark.client; + +import 
com.amazonaws.services.emrserverless.model.CancelJobRunResult;
+import com.amazonaws.services.emrserverless.model.GetJobRunResult;
+
+/**
+ * Client interface for Spark job submissions. Can have multiple implementations based on the
+ * underlying Spark infrastructure. Currently, we have one for EMRServerless {@link
+ * EmrServerlessClientImplEMR}.
+ */
+public interface EMRServerlessClient {
+
+  /**
+   * Start a new job run.
+   *
+   * @param startJobRequest startJobRequest
+   * @return jobId.
+   */
+  String startJobRun(StartJobRequest startJobRequest);
+
+  /**
+   * Get status of emr serverless job run.
+   *
+   * @param applicationId serverless applicationId
+   * @param jobId jobId.
+   * @return {@link GetJobRunResult}
+   */
+  GetJobRunResult getJobRunResult(String applicationId, String jobId);
+
+  /**
+   * Cancel emr serverless job run.
+   *
+   * @param applicationId applicationId.
+   * @param jobId jobId.
+   * @return {@link CancelJobRunResult}
+   */
+  CancelJobRunResult cancelJobRun(String applicationId, String jobId);
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java
similarity index 82%
rename from spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
rename to spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java
index 2377b2f5da..83e570ece2 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImplEMR.java
@@ -23,34 +23,31 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
-public class EmrServerlessClientImpl implements SparkJobClient {
+public class EmrServerlessClientImplEMR implements EMRServerlessClient {
 
   private final AWSEMRServerless emrServerless;
-  private static final Logger logger = LogManager.getLogger(EmrServerlessClientImpl.class);
+  private static final Logger logger = LogManager.getLogger(EmrServerlessClientImplEMR.class);
 
-  public EmrServerlessClientImpl(AWSEMRServerless emrServerless) {
+  public EmrServerlessClientImplEMR(AWSEMRServerless emrServerless) {
     this.emrServerless = emrServerless;
   }
 
   @Override
-  public String startJobRun(
-      String query,
-      String jobName,
-      String applicationId,
-      String executionRoleArn,
-      String sparkSubmitParams) {
+  public String startJobRun(StartJobRequest startJobRequest) {
     StartJobRunRequest request =
         new StartJobRunRequest()
-            .withName(jobName)
-            .withApplicationId(applicationId)
-            .withExecutionRoleArn(executionRoleArn)
+            .withName(startJobRequest.getJobName())
+            .withApplicationId(startJobRequest.getApplicationId())
+            .withExecutionRoleArn(startJobRequest.getExecutionRoleArn())
+            .withTags(startJobRequest.getTags())
             .withJobDriver(
                 new JobDriver()
                     .withSparkSubmit(
                         new SparkSubmit()
                             .withEntryPoint(SPARK_SQL_APPLICATION_JAR)
-                            .withEntryPointArguments(query, SPARK_RESPONSE_BUFFER_INDEX_NAME)
-                            .withSparkSubmitParameters(sparkSubmitParams)));
+                            .withEntryPointArguments(
+                                startJobRequest.getQuery(), SPARK_RESPONSE_BUFFER_INDEX_NAME)
+                            .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams())));
     StartJobRunResult startJobRunResult =
         AccessController.doPrivileged(
            (PrivilegedAction<StartJobRunResult>) () -> emrServerless.startJobRun(request));
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java
deleted file mode
100644
index c6b3059c77..0000000000
--- a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- *
- * * Copyright OpenSearch Contributors
- * * SPDX-License-Identifier: Apache-2.0
- *
- */
-
-package org.opensearch.sql.spark.client;
-
-import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
-import com.amazonaws.services.emrserverless.model.GetJobRunResult;
-
-public interface SparkJobClient {
-
-  String startJobRun(
-      String query,
-      String jobName,
-      String applicationId,
-      String executionRoleArn,
-      String sparkSubmitParams);
-
-  GetJobRunResult getJobRunResult(String applicationId, String jobId);
-
-  CancelJobRunResult cancelJobRun(String applicationId, String jobId);
-}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java
new file mode 100644
index 0000000000..94689c7030
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.client;
+
+import java.util.Map;
+import lombok.Data;
+
+/**
+ * This POJO carries all the fields required for emr serverless job submission. Used as model in
+ * {@link EMRServerlessClient} interface.
+ */
+@Data
+public class StartJobRequest {
+  private final String query;
+  private final String jobName;
+  private final String applicationId;
+  private final String executionRoleArn;
+  private final String sparkSubmitParams;
+  private final Map<String, String> tags;
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index 442838331f..904d199663 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -20,43 +20,49 @@
 import com.amazonaws.services.emrserverless.model.JobRunState;
 import java.net.URI;
 import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.Map;
 import lombok.AllArgsConstructor;
 import org.json.JSONObject;
 import org.opensearch.sql.datasource.DataSourceService;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
+import org.opensearch.sql.datasource.model.DataSourceType;
+import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
 import org.opensearch.sql.spark.asyncquery.model.S3GlueSparkSubmitParameters;
-import org.opensearch.sql.spark.client.SparkJobClient;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.client.StartJobRequest;
+import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
+import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
+import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
+import org.opensearch.sql.spark.rest.model.LangType;
+import org.opensearch.sql.spark.utils.SQLQueryUtils;
 
 /** This class takes care of understanding query and dispatching job query to emr serverless.
*/ @AllArgsConstructor public class SparkQueryDispatcher { - private SparkJobClient sparkJobClient; + public static final String INDEX_TAG_KEY = "index"; + public static final String DATASOURCE_TAG_KEY = "datasource"; + public static final String SCHEMA_TAG_KEY = "schema"; + public static final String TABLE_TAG_KEY = "table"; + public static final String CLUSTER_NAME_TAG_KEY = "cluster"; + + private EMRServerlessClient EMRServerlessClient; private DataSourceService dataSourceService; + private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; + private JobExecutionResponseReader jobExecutionResponseReader; - public String dispatch(String applicationId, String query, String executionRoleARN) { - String datasourceName = getDataSourceName(); - try { - return sparkJobClient.startJobRun( - query, - "flint-opensearch-query", - applicationId, - executionRoleARN, - constructSparkParameters(datasourceName)); - } catch (URISyntaxException e) { - throw new IllegalArgumentException( - String.format( - "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); - } + public String dispatch(DispatchQueryRequest dispatchQueryRequest) { + return EMRServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest)); } // TODO : Fetch from Result Index and then make call to EMR Serverless. public JSONObject getQueryResponse(String applicationId, String queryId) { - GetJobRunResult getJobRunResult = sparkJobClient.getJobRunResult(applicationId, queryId); + GetJobRunResult getJobRunResult = EMRServerlessClient.getJobRunResult(applicationId, queryId); JSONObject result = new JSONObject(); if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId); @@ -66,23 +72,33 @@ public JSONObject getQueryResponse(String applicationId, String queryId) { } public String cancelJob(String applicationId, String jobId) { - CancelJobRunResult cancelJobRunResult = sparkJobClient.cancelJobRun(applicationId, jobId); + CancelJobRunResult cancelJobRunResult = EMRServerlessClient.cancelJobRun(applicationId, jobId); return cancelJobRunResult.getJobRunId(); } - // TODO: Analyze given query - // Extract datasourceName - // Apply Authorizaiton. - private String getDataSourceName() { - return "my_glue"; + private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryRequest) { + if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { + if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) + return getStartJobRequestForIndexRequest(dispatchQueryRequest); + else { + return getStartJobRequestForNonIndexQueries(dispatchQueryRequest); + } + } + throw new UnsupportedOperationException( + String.format("UnSupported Lang type:: %s", dispatchQueryRequest.getLangType())); } - // TODO: Analyze given query and get the role arn based on datasource type. 
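
To illustrate the reworked entry point above: callers now build a DispatchQueryRequest rather than passing raw strings, and the dispatcher itself resolves datasource metadata, authorization, tags, and Spark submit parameters. The following sketch assumes a fully wired dispatcher; the application id, execution role ARN, and cluster name are hypothetical placeholders rather than values taken from this patch.

    // Sketch of caller-side usage of the new dispatch(DispatchQueryRequest) API.
    // The dispatcher dependencies are assumed to be constructed elsewhere.
    SparkQueryDispatcher dispatcher =
        new SparkQueryDispatcher(
            emrServerlessClient,               // EMRServerlessClient implementation
            dataSourceService,                 // resolves datasource metadata (e.g. my_glue)
            dataSourceUserAuthorizationHelper, // enforces datasource-level authorization
            jobExecutionResponseReader);       // reads results back from the OpenSearch index

    DispatchQueryRequest request =
        new DispatchQueryRequest(
            "00fd775baqpu4g0p",                        // EMR Serverless application id (placeholder)
            "select * from my_glue.default.http_logs", // SQL query; the datasource prefix is required
            LangType.SQL,                              // PPL is rejected with UnsupportedOperationException
            "arn:aws:iam::123456789012:role/emr-job-execution-role", // execution role ARN (placeholder)
            "TEST_CLUSTER");                           // cluster name, used in the job name and tags

    String jobId = dispatcher.dispatch(request);       // returns the EMR Serverless job run id
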
private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) { - return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + if (DataSourceType.S3GLUE.equals(dataSourceMetadata.getConnector())) { + return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + } + throw new UnsupportedOperationException( + String.format( + "UnSupported datasource type for async queries:: %s", + dataSourceMetadata.getConnector())); } - private String constructSparkParameters(String datasourceName) throws URISyntaxException { + private String constructSparkParameters(String datasourceName) { DataSourceMetadata dataSourceMetadata = dataSourceService.getRawDataSourceMetadata(datasourceName); S3GlueSparkSubmitParameters s3GlueSparkSubmitParameters = new S3GlueSparkSubmitParameters(); @@ -93,7 +109,14 @@ private String constructSparkParameters(String datasourceName) throws URISyntaxE s3GlueSparkSubmitParameters.addParameter( HIVE_METASTORE_GLUE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); String opensearchuri = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.uri"); - URI uri = new URI(opensearchuri); + URI uri; + try { + uri = new URI(opensearchuri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); + } String auth = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.auth"); String region = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.region"); s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); @@ -106,4 +129,80 @@ private String constructSparkParameters(String datasourceName) throws URISyntaxE "spark.sql.catalog." + datasourceName, FLINT_DELEGATE_CATALOG); return s3GlueSparkSubmitParameters.toString(); } + + private StartJobRequest getStartJobRequestForNonIndexQueries( + DispatchQueryRequest dispatchQueryRequest) { + StartJobRequest startJobRequest; + FullyQualifiedTableName fullyQualifiedTableName = + SQLQueryUtils.extractFullyQualifiedTableName(dispatchQueryRequest.getQuery()); + if (fullyQualifiedTableName.getDatasourceName() == null) { + throw new UnsupportedOperationException("Missing datasource in the query syntax."); + } + dataSourceUserAuthorizationHelper.authorizeDataSource( + this.dataSourceService.getRawDataSourceMetadata( + fullyQualifiedTableName.getDatasourceName())); + String jobName = + dispatchQueryRequest.getClusterName() + + ":" + + fullyQualifiedTableName.getFullyQualifiedName(); + Map tags = + getDefaultTagsForJobSubmission(dispatchQueryRequest, fullyQualifiedTableName); + startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), + tags); + return startJobRequest; + } + + private StartJobRequest getStartJobRequestForIndexRequest( + DispatchQueryRequest dispatchQueryRequest) { + StartJobRequest startJobRequest; + IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + if (fullyQualifiedTableName.getDatasourceName() == null) { + throw new UnsupportedOperationException("Queries without a datasource are not supported"); + } + dataSourceUserAuthorizationHelper.authorizeDataSource( + 
this.dataSourceService.getRawDataSourceMetadata( + fullyQualifiedTableName.getDatasourceName())); + String jobName = + getJobNameForIndexQuery(dispatchQueryRequest, indexDetails, fullyQualifiedTableName); + Map<String, String> tags = + getDefaultTagsForJobSubmission(dispatchQueryRequest, fullyQualifiedTableName); + tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); + startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), + tags); + return startJobRequest; + } + + private static Map<String, String> getDefaultTagsForJobSubmission( + DispatchQueryRequest dispatchQueryRequest, FullyQualifiedTableName fullyQualifiedTableName) { + Map<String, String> tags = new HashMap<>(); + tags.put(CLUSTER_NAME_TAG_KEY, dispatchQueryRequest.getClusterName()); + tags.put(DATASOURCE_TAG_KEY, fullyQualifiedTableName.getDatasourceName()); + tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); + tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); + return tags; + } + + private static String getJobNameForIndexQuery( + DispatchQueryRequest dispatchQueryRequest, + IndexDetails indexDetails, + FullyQualifiedTableName fullyQualifiedTableName) { + return dispatchQueryRequest.getClusterName() + + ":" + + fullyQualifiedTableName.getFullyQualifiedName() + + "." + + indexDetails.getIndexName(); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java new file mode 100644 index 0000000000..330eb3a03e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import lombok.Data; +import org.opensearch.sql.spark.rest.model.LangType; + +@Data +public class DispatchQueryRequest { + private final String applicationId; + private final String query; + private final LangType langType; + private final String executionRoleARN; + private final String clusterName; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java new file mode 100644 index 0000000000..5a9fe4d31f --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import java.util.Arrays; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** Fully Qualified Table Name in the query provided. */ +@Data +@NoArgsConstructor +public class FullyQualifiedTableName { + private String datasourceName; + private String schemaName; + private String tableName; + private String fullyQualifiedName; + + /** + * This constructor also takes care of splitting the fully qualified name into its respective + * pieces. If the name has three or more parts, the first is assigned to the datasource name, the + * second to the schema name, and the remainder to the table name. With two parts, the first is + * assigned to the schema name and the second to the table name; with a single part, it is assigned to the table name.
+ * + * @param fullyQualifiedName fullyQualifiedName. + */ + public FullyQualifiedTableName(String fullyQualifiedName) { + this.fullyQualifiedName = fullyQualifiedName; + String[] parts = fullyQualifiedName.split("\\."); + if (parts.length >= 3) { + datasourceName = parts[0]; + schemaName = parts[1]; + tableName = String.join(".", Arrays.copyOfRange(parts, 2, parts.length)); + } else if (parts.length == 2) { + schemaName = parts[0]; + tableName = parts[1]; + } else if (parts.length == 1) { + tableName = parts[0]; + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java new file mode 100644 index 0000000000..5067439061 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import lombok.Data; + +/** Index details in an async query. */ +@Data +public class IndexDetails { + private String indexName; + private FullyQualifiedTableName fullyQualifiedTableName; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 1e46ae48d2..c1ad979877 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -17,24 +17,27 @@ public class CreateAsyncQueryRequest { private String query; - private String lang; + private LangType lang; public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; - String lang = null; + LangType lang = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); if (fieldName.equals("query")) { query = parser.textOrNull(); - } else if (fieldName.equals("kind")) { - lang = parser.textOrNull(); + } else if (fieldName.equals("lang")) { + lang = LangType.fromString(parser.textOrNull()); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } + if (lang == null || query == null) { + throw new IllegalArgumentException("lang and query are required fields."); + } return new CreateAsyncQueryRequest(query, lang); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java new file mode 100644 index 0000000000..51fa8d2b13 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +/** Language type accepted in async query apis. */ +public enum LangType { + SQL("sql"), + PPL("ppl"); + private final String text; + + LangType(String text) { + this.text = text; + } + + public String getText() { + return this.text; + } + + /** + * Get LangType from text. + * + * @param text text. + * @return LangType {@link LangType}. 
+ */ + public static LangType fromString(String text) { + for (LangType langType : LangType.values()) { + if (langType.text.equalsIgnoreCase(text)) { + return langType; + } + } + throw new IllegalArgumentException("No LangType with text " + text + " found"); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java new file mode 100644 index 0000000000..2ddc34af5a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import lombok.Getter; +import lombok.experimental.UtilityClass; +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor; +import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer; +import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +/** + * This util class parses spark sql query and provides util functions to identify indexName, + * tableName and datasourceName. + */ +@UtilityClass +public class SQLQueryUtils { + + // TODO Handle cases where the query has multiple table Names. 
+ public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); + sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); + SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); + SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor(); + statement.accept(sparkSqlTableNameVisitor); + return sparkSqlTableNameVisitor.getFullyQualifiedTableName(); + } + + public static IndexDetails extractIndexDetails(String sqlQuery) { + FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = + new FlintSparkSqlExtensionsParser( + new CommonTokenStream( + new FlintSparkSqlExtensionsLexer(new CaseInsensitiveCharStream(sqlQuery)))); + flintSparkSqlExtensionsParser.addErrorListener(new SyntaxAnalysisErrorListener()); + FlintSparkSqlExtensionsParser.StatementContext statementContext = + flintSparkSqlExtensionsParser.statement(); + FlintSQLIndexDetailsVisitor flintSQLIndexDetailsVisitor = new FlintSQLIndexDetailsVisitor(); + statementContext.accept(flintSQLIndexDetailsVisitor); + return flintSQLIndexDetailsVisitor.getIndexDetails(); + } + + public static boolean isIndexQuery(String sqlQuery) { + FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = + new FlintSparkSqlExtensionsParser( + new CommonTokenStream( + new FlintSparkSqlExtensionsLexer(new CaseInsensitiveCharStream(sqlQuery)))); + flintSparkSqlExtensionsParser.addErrorListener(new SyntaxAnalysisErrorListener()); + try { + flintSparkSqlExtensionsParser.statement(); + return true; + } catch (SyntaxCheckException syntaxCheckException) { + return false; + } + } + + public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { + + @Getter private FullyQualifiedTableName fullyQualifiedTableName; + + public SparkSqlTableNameVisitor() { + this.fullyQualifiedTableName = new FullyQualifiedTableName(); + } + + @Override + public Void visitTableName(SqlBaseParser.TableNameContext ctx) { + fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText()); + return super.visitTableName(ctx); + } + + @Override + public Void visitDropTable(SqlBaseParser.DropTableContext ctx) { + for (ParseTree parseTree : ctx.children) { + if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { + fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + } + } + return super.visitDropTable(ctx); + } + + @Override + public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) { + for (ParseTree parseTree : ctx.children) { + if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { + fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + } + } + return super.visitDescribeRelation(ctx); + } + + // Extract table name for create Table Statement. 
+ @Override + public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { + for (ParseTree parseTree : ctx.children) { + if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { + fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + } + } + return super.visitCreateTableHeader(ctx); + } + } + + public static class FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor { + + @Getter private final IndexDetails indexDetails; + + public FlintSQLIndexDetailsVisitor() { + this.indexDetails = new IndexDetails(); + } + + @Override + public Void visitIndexName(FlintSparkSqlExtensionsParser.IndexNameContext ctx) { + indexDetails.setIndexName(ctx.getText()); + return super.visitIndexName(ctx); + } + + @Override + public Void visitTableName(FlintSparkSqlExtensionsParser.TableNameContext ctx) { + indexDetails.setFullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); + return super.visitTableName(ctx); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 5e832777fc..1ff2493e6d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.utils.TestUtils.getJson; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -23,13 +24,16 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; @ExtendWith(MockitoExtension.class) public class AsyncQueryExecutorServiceImplTest { @@ -44,25 +48,34 @@ void testCreateAsyncQuery() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", "sql"); + new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", LangType.SQL); when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) .thenReturn( "{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}"); + when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) + .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch( - "00fd775baqpu4g0p", - "select * from 
my_glue.default.http_logs", - "arn:aws:iam::270824043731:role/emr-job-execution-role")) + new DispatchQueryRequest( + "00fd775baqpu4g0p", + "select * from my_glue.default.http_logs", + LangType.SQL, + "arn:aws:iam::270824043731:role/emr-job-execution-role", + TEST_CLUSTER_NAME))) .thenReturn(EMR_JOB_ID); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) .storeJobMetadata(new AsyncQueryJobMetadata(EMR_JOB_ID, "00fd775baqpu4g0p")); verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); + verify(settings, times(1)).getSettingValue(Settings.Key.CLUSTER_NAME); verify(sparkQueryDispatcher, times(1)) .dispatch( - "00fd775baqpu4g0p", - "select * from my_glue.default.http_logs", - "arn:aws:iam::270824043731:role/emr-job-execution-role"); + new DispatchQueryRequest( + "00fd775baqpu4g0p", + "select * from my_glue.default.http_logs", + LangType.SQL, + "arn:aws:iam::270824043731:role/emr-job-execution-role", + TEST_CLUSTER_NAME)); Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 925ee73bcd..0765b90534 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -20,6 +20,7 @@ import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.StartJobRunResult; import com.amazonaws.services.emrserverless.model.ValidationException; +import java.util.HashMap; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -35,9 +36,15 @@ void testStartJobRun() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + EmrServerlessClientImplEMR emrServerlessClient = new EmrServerlessClientImplEMR(emrServerless); emrServerlessClient.startJobRun( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, SPARK_SUBMIT_PARAMETERS); + new StartJobRequest( + QUERY, + EMRS_JOB_NAME, + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + SPARK_SUBMIT_PARAMETERS, + new HashMap<>())); } @Test @@ -47,7 +54,7 @@ void testGetJobRunState() { GetJobRunResult response = new GetJobRunResult(); response.setJobRun(jobRun); when(emrServerless.getJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + EmrServerlessClientImplEMR emrServerlessClient = new EmrServerlessClientImplEMR(emrServerless); emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); } @@ -55,7 +62,7 @@ void testGetJobRunState() { void testCancelJobRun() { when(emrServerless.cancelJobRun(any())) .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); - EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + EmrServerlessClientImplEMR emrServerlessClient = new EmrServerlessClientImplEMR(emrServerless); CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); @@ -64,7 
+71,7 @@ void testCancelJobRun() { @Test void testCancelJobRunWithValidationException() { doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + EmrServerlessClientImplEMR emrServerlessClient = new EmrServerlessClientImplEMR(emrServerless); IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index e455e6a049..abae0377a2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -15,4 +15,5 @@ public class TestConstants { public static final String EMRS_DATASOURCE_ROLE = "datasource_role"; public static final String EMRS_JOB_NAME = "job_name"; public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; + public static final String TEST_CLUSTER_NAME = "TEST_CLUSTER"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2000eeefed..d83505fde0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -12,7 +13,7 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -31,60 +32,286 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.model.LangType; @ExtendWith(MockitoExtension.class) public class SparkQueryDispatcherTest { - @Mock private SparkJobClient sparkJobClient; + @Mock private EMRServerlessClient EMRServerlessClient; @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; + @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @Test - void testDispatch() { + void testDispatchSelectQuery() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, 
dataSourceService, jobExecutionResponseReader); - when(sparkJobClient.startJobRun( - QUERY, - "flint-opensearch-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString())) + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("table", "http_logs"); + tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("schema", "default"); + String query = "select * from my_glue.default.http_logs"; + when(EMRServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:my_glue.default.http_logs", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) .thenReturn(EMR_JOB_ID); - when(dataSourceService.getRawDataSourceMetadata("my_glue")) - .thenReturn(constructMyGlueDataSourceMetadata()); - String jobId = sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE); - verify(sparkJobClient, times(1)) + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(EMRServerlessClient, times(1)) + .startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:my_glue.default.http_logs", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + + @Test + void testDispatchIndexQuery() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("table", "http_logs"); + tags.put("index", "elb_and_requestUri"); + tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("schema", "default"); + String query = + "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + when(EMRServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:my_glue.default.http_logs.elb_and_requestUri", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(EMRServerlessClient, times(1)) .startJobRun( - QUERY, - "flint-opensearch-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString()); + new StartJobRequest( + query, + "TEST_CLUSTER:my_glue.default.http_logs.elb_and_requestUri", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); Assertions.assertEquals(EMR_JOB_ID, jobId); } + @Test + void 
testDispatchWithPPLQuery() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = "select * from my_glue.default.http_logs"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.PPL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME))); + Assertions.assertEquals( + "UnSupported Lang type:: PPL", unsupportedOperationException.getMessage()); + verifyNoInteractions(EMRServerlessClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testDispatchQueryWithoutATableName() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = "show tables"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME))); + Assertions.assertEquals( + "Missing datasource in the query syntax.", unsupportedOperationException.getMessage()); + verifyNoInteractions(EMRServerlessClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testDispatchQueryWithoutADataSourceName() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = "select * from default.http_logs"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME))); + Assertions.assertEquals( + "Missing datasource in the query syntax.", unsupportedOperationException.getMessage()); + verifyNoInteractions(EMRServerlessClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testDispatchIndexQueryWithoutADatasourceName() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = + "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME))); + Assertions.assertEquals( + "Queries without a datasource are not supported", + unsupportedOperationException.getMessage()); + verifyNoInteractions(EMRServerlessClient); + 
verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + @Test void testDispatchWithWrongURI() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); when(dataSourceService.getRawDataSourceMetadata("my_glue")) .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); + String query = "select * from my_glue.default.http_logs"; IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, - () -> sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE)); + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME))); Assertions.assertEquals( "Bad URI in indexstore configuration of the : my_glue datasoure.", illegalArgumentException.getMessage()); } + @Test + void testDispatchWithUnSupportedDataSourceType() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + when(dataSourceService.getRawDataSourceMetadata("my_prometheus")) + .thenReturn(constructPrometheusDataSourceType()); + String query = "select * from my_prometheus.default.http_logs"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME))); + Assertions.assertEquals( + "UnSupported datasource type for async queries:: PROMETHEUS", + unsupportedOperationException.getMessage()); + } + @Test void testCancelJob() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); - when(sparkJobClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + when(EMRServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) @@ -96,8 +323,12 @@ void testCancelJob() { @Test void testGetQueryResponse() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); - when(sparkJobClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + when(EMRServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); Assertions.assertEquals("PENDING", result.get("status")); @@ -107,15 +338,19 @@ void testGetQueryResponse() { @Test void testGetQueryResponseWithSuccess() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); - 
when(sparkJobClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + new SparkQueryDispatcher( + EMRServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + when(EMRServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); JSONObject queryResult = new JSONObject(); queryResult.put("data", "result"); when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) .thenReturn(queryResult); JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); - verify(sparkJobClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(EMRServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); Assertions.assertEquals("result", result.get("data")); @@ -185,4 +420,13 @@ private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { dataSourceMetadata.setProperties(properties); return dataSourceMetadata; } + + private DataSourceMetadata constructPrometheusDataSourceType() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_prometheus"); + dataSourceMetadata.setConnector(DataSourceType.PROMETHEUS); + Map properties = new HashMap<>(); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 6596a9e820..ef49d29829 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -27,6 +27,7 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -56,7 +57,7 @@ public void setUp() { @Test public void testDoExecute() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "sql"); + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", LangType.SQL); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) @@ -72,7 +73,7 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "sql"); + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", LangType.SQL); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); doThrow(new RuntimeException("Error")) diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java 
b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java new file mode 100644 index 0000000000..91b5befe88 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +@ExtendWith(MockitoExtension.class) +public class SQLQueryUtilsTest { + + @Test + void testExtractionOfTableNameFromSQLQueries() { + String sqlQuery = "select * from my_glue.default.http_logs"; + FullyQualifiedTableName fullyQualifiedTableName = + SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "select * from my_glue.db.http_logs"; + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("db", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "select * from my_glue.http_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("my_glue", fullyQualifiedTableName.getSchemaName()); + Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "select * from http_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "DROP TABLE myS3.default.alb_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "DESCRIBE TABLE myS3.default.alb_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = + "CREATE EXTERNAL TABLE\n" + + "myS3.default.alb_logs\n" + + "[ PARTITIONED BY 
(col_name [, … ] ) ]\n" + + "[ ROW FORMAT DELIMITED row_format ]\n" + + "STORED AS file_format\n" + + "LOCATION { 's3://bucket/folder/' }"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + } + + @Test + void testErrorScenarios() { + String sqlQuery = "SHOW tables"; + FullyQualifiedTableName fullyQualifiedTableName = + SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertNotNull(fullyQualifiedTableName); + Assertions.assertNull(fullyQualifiedTableName.getFullyQualifiedName()); + Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); + Assertions.assertNull(fullyQualifiedTableName.getTableName()); + Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + + sqlQuery = "DESCRIBE TABLE FROM myS3.default.alb_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("FROM", fullyQualifiedTableName.getFullyQualifiedName()); + Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("FROM", fullyQualifiedTableName.getTableName()); + Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + } + + @Test + void testExtractionFromFlintIndexQueries() { + String createCoveredIndexQuery = + "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + Assertions.assertTrue(SQLQueryUtils.isIndexQuery(createCoveredIndexQuery)); + IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertEquals("elb_and_requestUri", indexDetails.getIndexName()); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + } +}
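
As a recap of the behaviour these tests pin down, the sketch below shows the two parsing paths side by side; the query strings are illustrative and mirror the cases asserted in SQLQueryUtilsTest.

    // Regular SQL: the table reference is split into datasource.schema.table.
    FullyQualifiedTableName name =
        SQLQueryUtils.extractFullyQualifiedTableName("select * from my_glue.default.http_logs");
    // name.getDatasourceName() -> "my_glue"
    // name.getSchemaName()     -> "default"
    // name.getTableName()      -> "http_logs"

    // Flint index DDL is detected by isIndexQuery() and parsed into IndexDetails instead.
    String indexQuery =
        "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity)"
            + " WITH (auto_refresh = true)";
    if (SQLQueryUtils.isIndexQuery(indexQuery)) {
      IndexDetails details = SQLQueryUtils.extractIndexDetails(indexQuery);
      // details.getIndexName()                                   -> "elb_and_requestUri"
      // details.getFullyQualifiedTableName().getDatasourceName() -> "my_glue"
    }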