diff --git a/docs/user/ppl/admin/connectors/spark_connector.rst b/docs/user/ppl/admin/connectors/spark_connector.rst new file mode 100644 index 0000000000..59a52998bc --- /dev/null +++ b/docs/user/ppl/admin/connectors/spark_connector.rst @@ -0,0 +1,92 @@ +.. highlight:: sh + +==================== +Spark Connector +==================== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 1 + + +Introduction +============ + +This page covers spark connector properties for dataSource configuration +and the nuances associated with spark connector. + + +Spark Connector Properties in DataSource Configuration +======================================================== +Spark Connector Properties. + +* ``spark.connector`` [Required]. + * This parameters provides the spark client information for connection. +* ``spark.sql.application`` [Optional]. + * This parameters provides the spark sql application jar. Default value is ``s3://spark-datasource/sql-job.jar``. +* ``emr.cluster`` [Required]. + * This parameters provides the emr cluster id information. +* ``emr.auth.type`` [Required] + * This parameters provides the authentication type information. + * Spark emr connector currently supports ``awssigv4`` authentication mechanism and following parameters are required. + * ``emr.auth.region``, ``emr.auth.access_key`` and ``emr.auth.secret_key`` +* ``spark.datasource.flint.*`` [Optional] + * This parameters provides the Opensearch domain host information for flint integration. + * ``spark.datasource.flint.integration`` [Optional] + * Default value for integration jar is ``s3://spark-datasource/flint-spark-integration-assembly-0.3.0-SNAPSHOT.jar``. + * ``spark.datasource.flint.host`` [Optional] + * Default value for host is ``localhost``. + * ``spark.datasource.flint.port`` [Optional] + * Default value for port is ``9200``. + * ``spark.datasource.flint.scheme`` [Optional] + * Default value for scheme is ``http``. + * ``spark.datasource.flint.auth`` [Optional] + * Default value for auth is ``false``. + * ``spark.datasource.flint.region`` [Optional] + * Default value for auth is ``us-west-2``. + +Example spark dataSource configuration +======================================== + +AWSSigV4 Auth:: + + [{ + "name" : "my_spark", + "connector": "spark", + "properties" : { + "spark.connector": "emr", + "emr.cluster" : "{{clusterId}}", + "emr.auth.type" : "awssigv4", + "emr.auth.region" : "us-east-1", + "emr.auth.access_key" : "{{accessKey}}" + "emr.auth.secret_key" : "{{secretKey}}" + "spark.datasource.flint.host" : "{{opensearchHost}}", + "spark.datasource.flint.port" : "{{opensearchPort}}", + "spark.datasource.flint.scheme" : "{{opensearchScheme}}", + "spark.datasource.flint.auth" : "{{opensearchAuth}}", + "spark.datasource.flint.region" : "{{opensearchRegion}}", + } + }] + + +Spark SQL Support +================== + +`sql` Function +---------------------------- +Spark connector offers `sql` function. This function can be used to run spark sql query. +The function takes spark sql query as input. Argument should be either passed by name or positionArguments should be either passed by name or position. +`source=my_spark.sql('select 1')` +or +`source=my_spark.sql(query='select 1')` +Example:: + + > source=my_spark.sql('select 1') + +---+ + | 1 | + |---+ + | 1 | + +---+ + diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index cfce8e9cfe..a9eb38a2c2 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -83,6 +83,7 @@ import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; +import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction; @@ -282,6 +283,7 @@ private DataSourceServiceImpl createDataSourceService() { new OpenSearchDataSourceFactory( new OpenSearchNodeClient(this.client), pluginSettings)) .add(new PrometheusStorageFactory(pluginSettings)) + .add(new SparkStorageFactory(this.client, pluginSettings)) .add(new GlueDataSourceFactory(pluginSettings)) .build(), dataSourceMetadataStorage, diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java new file mode 100644 index 0000000000..87f35bbc1e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.datasource.model.DataSourceMetadata.DEFAULT_RESULT_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; + +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.model.ActionOnFailure; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsRequest; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult; +import com.amazonaws.services.elasticmapreduce.model.DescribeStepRequest; +import com.amazonaws.services.elasticmapreduce.model.HadoopJarStepConfig; +import com.amazonaws.services.elasticmapreduce.model.StepConfig; +import com.amazonaws.services.elasticmapreduce.model.StepStatus; +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import lombok.SneakyThrows; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; + +public class EmrClientImpl implements SparkClient { + private final AmazonElasticMapReduce emr; + private final String emrCluster; + private final FlintHelper flint; + private final String sparkApplicationJar; + private static final Logger logger = LogManager.getLogger(EmrClientImpl.class); + private SparkResponse sparkResponse; + + /** + * Constructor for EMR Client Implementation. + * + * @param emr EMR helper + * @param flint Opensearch args for flint integration jar + * @param sparkResponse Response object to help with retrieving results from Opensearch index + */ + public EmrClientImpl( + AmazonElasticMapReduce emr, + String emrCluster, + FlintHelper flint, + SparkResponse sparkResponse, + String sparkApplicationJar) { + this.emr = emr; + this.emrCluster = emrCluster; + this.flint = flint; + this.sparkResponse = sparkResponse; + this.sparkApplicationJar = + sparkApplicationJar == null ? SPARK_SQL_APPLICATION_JAR : sparkApplicationJar; + } + + @Override + public JSONObject sql(String query) throws IOException { + runEmrApplication(query); + return sparkResponse.getResultFromOpensearchIndex(); + } + + @VisibleForTesting + void runEmrApplication(String query) { + + HadoopJarStepConfig stepConfig = + new HadoopJarStepConfig() + .withJar("command-runner.jar") + .withArgs( + "spark-submit", + "--class", + "org.opensearch.sql.SQLJob", + "--jars", + flint.getFlintIntegrationJar(), + sparkApplicationJar, + query, + DEFAULT_RESULT_INDEX, + flint.getFlintHost(), + flint.getFlintPort(), + flint.getFlintScheme(), + flint.getFlintAuth(), + flint.getFlintRegion()); + + StepConfig emrstep = + new StepConfig() + .withName("Spark Application") + .withActionOnFailure(ActionOnFailure.CONTINUE) + .withHadoopJarStep(stepConfig); + + AddJobFlowStepsRequest request = + new AddJobFlowStepsRequest().withJobFlowId(emrCluster).withSteps(emrstep); + + AddJobFlowStepsResult result = emr.addJobFlowSteps(request); + logger.info("EMR step ID: " + result.getStepIds()); + + String stepId = result.getStepIds().get(0); + DescribeStepRequest stepRequest = + new DescribeStepRequest().withClusterId(emrCluster).withStepId(stepId); + + waitForStepExecution(stepRequest); + sparkResponse.setValue(stepId); + } + + @SneakyThrows + private void waitForStepExecution(DescribeStepRequest stepRequest) { + // Wait for the step to complete + boolean completed = false; + while (!completed) { + // Get the step status + StepStatus statusDetail = emr.describeStep(stepRequest).getStep().getStatus(); + // Check if the step has completed + if (statusDetail.getState().equals("COMPLETED")) { + completed = true; + logger.info("EMR step completed successfully."); + } else if (statusDetail.getState().equals("FAILED") + || statusDetail.getState().equals("CANCELLED")) { + logger.error("EMR step failed or cancelled."); + throw new RuntimeException("Spark SQL application failed."); + } else { + // Sleep for some time before checking the status again + Thread.sleep(2500); + } + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java new file mode 100644 index 0000000000..b38f04680b --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import java.io.IOException; +import org.json.JSONObject; + +/** Interface class for Spark Client. */ +public interface SparkClient { + /** + * This method executes spark sql query. + * + * @param query spark sql query + * @return spark query response + */ + JSONObject sql(String query) throws IOException; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java new file mode 100644 index 0000000000..914aa80085 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.implementation; + +import static org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver.QUERY; + +import java.util.List; +import java.util.stream.Collectors; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.spark.storage.SparkTable; +import org.opensearch.sql.storage.Table; + +/** Spark SQL function implementation. */ +public class SparkSqlFunctionImplementation extends FunctionExpression + implements TableFunctionImplementation { + + private final FunctionName functionName; + private final List arguments; + private final SparkClient sparkClient; + + /** + * Constructor for spark sql function. + * + * @param functionName name of the function + * @param arguments a list of expressions + * @param sparkClient spark client + */ + public SparkSqlFunctionImplementation( + FunctionName functionName, List arguments, SparkClient sparkClient) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.sparkClient = sparkClient; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException( + String.format( + "Spark defined function [%s] is only " + + "supported in SOURCE clause with spark connector catalog", + functionName)); + } + + @Override + public ExprType type() { + return ExprCoreType.STRUCT; + } + + @Override + public String toString() { + List args = + arguments.stream() + .map( + arg -> + String.format( + "%s=%s", + ((NamedArgumentExpression) arg).getArgName(), + ((NamedArgumentExpression) arg).getValue().toString())) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } + + @Override + public Table applyArguments() { + return new SparkTable(sparkClient, buildQueryFromSqlFunction(arguments)); + } + + /** + * This method builds a spark query request. + * + * @param arguments spark sql function arguments + * @return spark query request + */ + private SparkQueryRequest buildQueryFromSqlFunction(List arguments) { + + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + arguments.forEach( + arg -> { + String argName = ((NamedArgumentExpression) arg).getArgName(); + Expression argValue = ((NamedArgumentExpression) arg).getValue(); + ExprValue literalValue = argValue.valueOf(); + if (argName.equals(QUERY)) { + sparkQueryRequest.setSql((String) literalValue.value()); + } else { + throw new ExpressionEvaluationException( + String.format("Invalid Function Argument:%s", argName)); + } + }); + return sparkQueryRequest; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java new file mode 100644 index 0000000000..a4f2a6c0fe --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.resolver; + +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.ArrayList; +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; + +/** Function resolver for sql function of spark connector. */ +@RequiredArgsConstructor +public class SparkSqlTableFunctionResolver implements FunctionResolver { + private final SparkClient sparkClient; + + public static final String SQL = "sql"; + public static final String QUERY = "query"; + + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + FunctionName functionName = FunctionName.of(SQL); + FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); + final List argumentNames = List.of(QUERY); + + FunctionBuilder functionBuilder = + (functionProperties, arguments) -> { + Boolean argumentsPassedByName = + arguments.stream() + .noneMatch( + arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); + Boolean argumentsPassedByPosition = + arguments.stream() + .allMatch( + arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); + if (!(argumentsPassedByName || argumentsPassedByPosition)) { + throw new SemanticCheckException( + "Arguments should be either passed by name or position"); + } + + if (arguments.size() != argumentNames.size()) { + throw new SemanticCheckException( + String.format( + "Missing arguments:[%s]", + String.join( + ",", argumentNames.subList(arguments.size(), argumentNames.size())))); + } + + if (argumentsPassedByPosition) { + List namedArguments = new ArrayList<>(); + for (int i = 0; i < arguments.size(); i++) { + namedArguments.add( + new NamedArgumentExpression( + argumentNames.get(i), + ((NamedArgumentExpression) arguments.get(i)).getValue())); + } + return new SparkSqlFunctionImplementation(functionName, namedArguments, sparkClient); + } + return new SparkSqlFunctionImplementation(functionName, arguments, sparkClient); + }; + return Pair.of(functionSignature, functionBuilder); + } + + @Override + public FunctionName getFunctionName() { + return FunctionName.of(SQL); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java new file mode 100644 index 0000000000..aea8f72f36 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.scan; + +import lombok.AllArgsConstructor; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** TableScanBuilder for sql function of spark connector. */ +@AllArgsConstructor +public class SparkSqlFunctionTableScanBuilder extends TableScanBuilder { + + private final SparkClient sparkClient; + + private final SparkQueryRequest sparkQueryRequest; + + @Override + public TableScanOperator build() { + return new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + } + + @Override + public boolean pushDownProject(LogicalProject project) { + return true; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java new file mode 100644 index 0000000000..a2e44affd5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.scan; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Locale; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.functions.response.SparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; + +/** This a table scan operator to handle sql table function. */ +@RequiredArgsConstructor +public class SparkSqlFunctionTableScanOperator extends TableScanOperator { + private final SparkClient sparkClient; + private final SparkQueryRequest request; + private SparkSqlFunctionResponseHandle sparkResponseHandle; + private static final Logger LOG = LogManager.getLogger(); + + @Override + public void open() { + super.open(); + this.sparkResponseHandle = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + JSONObject responseObject = sparkClient.sql(request.getSql()); + return new DefaultSparkSqlFunctionResponseHandle(responseObject); + } catch (IOException e) { + LOG.error(e.getMessage()); + throw new RuntimeException( + String.format("Error fetching data from spark server: %s", e.getMessage())); + } + }); + } + + @Override + public boolean hasNext() { + return this.sparkResponseHandle.hasNext(); + } + + @Override + public ExprValue next() { + return this.sparkResponseHandle.next(); + } + + @Override + public String explain() { + return String.format(Locale.ROOT, "sql(%s)", request.getSql()); + } + + @Override + public ExecutionEngine.Schema schema() { + return this.sparkResponseHandle.schema(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java new file mode 100644 index 0000000000..395e1685a6 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; + +/** Spark scan operator. */ +@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) +@ToString(onlyExplicitlyIncluded = true) +public class SparkScan extends TableScanOperator { + + private final SparkClient sparkClient; + + @EqualsAndHashCode.Include @Getter @Setter @ToString.Include private SparkQueryRequest request; + + /** + * Constructor. + * + * @param sparkClient sparkClient. + */ + public SparkScan(SparkClient sparkClient) { + this.sparkClient = sparkClient; + this.request = new SparkQueryRequest(); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public String explain() { + return getRequest().toString(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java new file mode 100644 index 0000000000..84c9c05e79 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import java.util.Collection; +import java.util.Collections; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.storage.Table; + +/** Spark storage engine implementation. */ +@RequiredArgsConstructor +public class SparkStorageEngine implements StorageEngine { + private final SparkClient sparkClient; + + @Override + public Collection getFunctions() { + return Collections.singletonList(new SparkSqlTableFunctionResolver(sparkClient)); + } + + @Override + public Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableName) { + throw new RuntimeException("Unable to get table from storage engine."); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java new file mode 100644 index 0000000000..467bacbaea --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; + +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; +import java.security.AccessController; +import java.security.InvalidParameterException; +import java.security.PrivilegedAction; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.Client; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.spark.client.EmrClientImpl; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.sql.storage.StorageEngine; + +/** Storage factory implementation for spark connector. */ +@RequiredArgsConstructor +public class SparkStorageFactory implements DataSourceFactory { + private final Client client; + private final Settings settings; + + // Spark datasource configuration properties + public static final String CONNECTOR_TYPE = "spark.connector"; + public static final String SPARK_SQL_APPLICATION = "spark.sql.application"; + + // EMR configuration properties + public static final String EMR_CLUSTER = "emr.cluster"; + public static final String EMR_AUTH_TYPE = "emr.auth.type"; + public static final String EMR_REGION = "emr.auth.region"; + public static final String EMR_ROLE_ARN = "emr.auth.role_arn"; + public static final String EMR_ACCESS_KEY = "emr.auth.access_key"; + public static final String EMR_SECRET_KEY = "emr.auth.secret_key"; + + // Flint integration jar configuration properties + public static final String FLINT_INTEGRATION = "spark.datasource.flint.integration"; + public static final String FLINT_HOST = "spark.datasource.flint.host"; + public static final String FLINT_PORT = "spark.datasource.flint.port"; + public static final String FLINT_SCHEME = "spark.datasource.flint.scheme"; + public static final String FLINT_AUTH = "spark.datasource.flint.auth"; + public static final String FLINT_REGION = "spark.datasource.flint.region"; + + @Override + public DataSourceType getDataSourceType() { + return DataSourceType.SPARK; + } + + @Override + public DataSource createDataSource(DataSourceMetadata metadata) { + return new DataSource( + metadata.getName(), DataSourceType.SPARK, getStorageEngine(metadata.getProperties())); + } + + /** + * This function gets spark storage engine. + * + * @param requiredConfig spark config options + * @return spark storage engine object + */ + StorageEngine getStorageEngine(Map requiredConfig) { + SparkClient sparkClient; + if (requiredConfig.get(CONNECTOR_TYPE).equals(EMR)) { + sparkClient = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + validateEMRConfigProperties(requiredConfig); + return new EmrClientImpl( + getEMRClient( + requiredConfig.get(EMR_ACCESS_KEY), + requiredConfig.get(EMR_SECRET_KEY), + requiredConfig.get(EMR_REGION)), + requiredConfig.get(EMR_CLUSTER), + new FlintHelper( + requiredConfig.get(FLINT_INTEGRATION), + requiredConfig.get(FLINT_HOST), + requiredConfig.get(FLINT_PORT), + requiredConfig.get(FLINT_SCHEME), + requiredConfig.get(FLINT_AUTH), + requiredConfig.get(FLINT_REGION)), + new SparkResponse(client, null, STEP_ID_FIELD), + requiredConfig.get(SPARK_SQL_APPLICATION)); + }); + } else { + throw new InvalidParameterException("Spark connector type is invalid."); + } + return new SparkStorageEngine(sparkClient); + } + + private void validateEMRConfigProperties(Map dataSourceMetadataConfig) + throws IllegalArgumentException { + if (dataSourceMetadataConfig.get(EMR_CLUSTER) == null + || dataSourceMetadataConfig.get(EMR_AUTH_TYPE) == null) { + throw new IllegalArgumentException("EMR config properties are missing."); + } else if (dataSourceMetadataConfig + .get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName()) + && (dataSourceMetadataConfig.get(EMR_ACCESS_KEY) == null + || dataSourceMetadataConfig.get(EMR_SECRET_KEY) == null)) { + throw new IllegalArgumentException("EMR auth keys are missing."); + } else if (!dataSourceMetadataConfig + .get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName())) { + throw new IllegalArgumentException("Invalid auth type."); + } + } + + private AmazonElasticMapReduce getEMRClient( + String emrAccessKey, String emrSecretKey, String emrRegion) { + return AmazonElasticMapReduceClientBuilder.standard() + .withCredentials( + new AWSStaticCredentialsProvider(new BasicAWSCredentials(emrAccessKey, emrSecretKey))) + .withRegion(emrRegion) + .build(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java new file mode 100644 index 0000000000..731c3df672 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import java.util.HashMap; +import java.util.Map; +import lombok.Getter; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** Spark table implementation. This can be constructed from SparkQueryRequest. */ +public class SparkTable implements Table { + + private final SparkClient sparkClient; + + @Getter private final SparkQueryRequest sparkQueryRequest; + + /** Constructor for entire Sql Request. */ + public SparkTable(SparkClient sparkService, SparkQueryRequest sparkQueryRequest) { + this.sparkClient = sparkService; + this.sparkQueryRequest = sparkQueryRequest; + } + + @Override + public boolean exists() { + throw new UnsupportedOperationException( + "Exists operation is not supported in spark datasource"); + } + + @Override + public void create(Map schema) { + throw new UnsupportedOperationException( + "Create operation is not supported in spark datasource"); + } + + @Override + public Map getFieldTypes() { + return new HashMap<>(); + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + SparkScan metricScan = new SparkScan(sparkClient); + metricScan.setRequest(sparkQueryRequest); + return plan.accept(new DefaultImplementor(), metricScan); + } + + @Override + public TableScanBuilder createScanBuilder() { + return new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java new file mode 100644 index 0000000000..93dc0d6bc8 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java @@ -0,0 +1,158 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult; +import com.amazonaws.services.elasticmapreduce.model.DescribeStepResult; +import com.amazonaws.services.elasticmapreduce.model.Step; +import com.amazonaws.services.elasticmapreduce.model.StepStatus; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; + +@ExtendWith(MockitoExtension.class) +public class EmrClientImplTest { + + @Mock private AmazonElasticMapReduce emr; + @Mock private FlintHelper flint; + @Mock private SparkResponse sparkResponse; + + @Test + @SneakyThrows + void testRunEmrApplication() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("COMPLETED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.runEmrApplication(QUERY); + } + + @Test + @SneakyThrows + void testRunEmrApplicationFailed() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("FAILED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = + Assertions.assertThrows( + RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testRunEmrApplicationCancelled() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("CANCELLED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = + Assertions.assertThrows( + RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testRunEmrApplicationRunnning() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus runningStatus = new StepStatus(); + runningStatus.setState("RUNNING"); + Step runningStep = new Step(); + runningStep.setStatus(runningStatus); + DescribeStepResult runningDescribeStepResult = new DescribeStepResult(); + runningDescribeStepResult.setStep(runningStep); + + StepStatus completedStatus = new StepStatus(); + completedStatus.setState("COMPLETED"); + Step completedStep = new Step(); + completedStep.setStatus(completedStatus); + DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); + completedDescribeStepResult.setStep(completedStep); + + when(emr.describeStep(any())) + .thenReturn(runningDescribeStepResult) + .thenReturn(completedDescribeStepResult); + + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.runEmrApplication(QUERY); + } + + @Test + @SneakyThrows + void testSql() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus runningStatus = new StepStatus(); + runningStatus.setState("RUNNING"); + Step runningStep = new Step(); + runningStep.setStatus(runningStatus); + DescribeStepResult runningDescribeStepResult = new DescribeStepResult(); + runningDescribeStepResult.setStep(runningStep); + + StepStatus completedStatus = new StepStatus(); + completedStatus.setState("COMPLETED"); + Step completedStep = new Step(); + completedStep.setStatus(completedStatus); + DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); + completedDescribeStepResult.setStep(completedStep); + + when(emr.describeStep(any())) + .thenReturn(runningDescribeStepResult) + .thenReturn(completedDescribeStepResult); + when(sparkResponse.getResultFromOpensearchIndex()) + .thenReturn(new JSONObject(getJson("select_query_response.json"))); + + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.sql(QUERY); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java b/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java index 3b1ea14d40..e58f240f5c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java @@ -11,30 +11,18 @@ import org.opensearch.sql.spark.data.type.SparkDataType; class SparkExprValueTest { - private final SparkDataType sparkDataType = new SparkDataType("char"); - @Test - public void getters() { - SparkExprValue sparkExprValue = new SparkExprValue(sparkDataType, "str"); - - assertEquals(sparkDataType, sparkExprValue.type()); - assertEquals("str", sparkExprValue.value()); + public void type() { + assertEquals( + new SparkDataType("char"), new SparkExprValue(new SparkDataType("char"), "str").type()); } @Test public void unsupportedCompare() { - SparkExprValue sparkExprValue = new SparkExprValue(sparkDataType, "str"); - - assertThrows(UnsupportedOperationException.class, () -> sparkExprValue.compare(sparkExprValue)); - } - - @Test - public void testEquals() { - SparkExprValue sparkExprValue1 = new SparkExprValue(sparkDataType, "str"); - SparkExprValue sparkExprValue2 = new SparkExprValue(sparkDataType, "str"); - SparkExprValue sparkExprValue3 = new SparkExprValue(sparkDataType, "other"); + SparkDataType type = new SparkDataType("char"); - assertTrue(sparkExprValue1.equal(sparkExprValue2)); - assertFalse(sparkExprValue1.equal(sparkExprValue3)); + assertThrows( + UnsupportedOperationException.class, + () -> new SparkExprValue(type, "str").compare(new SparkExprValue(type, "str"))); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java new file mode 100644 index 0000000000..120747e0d3 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.spark.storage.SparkTable; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlFunctionImplementationTest { + @Mock private SparkClient client; + + @Test + void testValueOfAndTypeToString() { + FunctionName functionName = new FunctionName("sql"); + List namedArgumentExpressionList = + List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation = + new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + UnsupportedOperationException exception = + assertThrows( + UnsupportedOperationException.class, () -> sparkSqlFunctionImplementation.valueOf()); + assertEquals( + "Spark defined function [sql] is only " + + "supported in SOURCE clause with spark connector catalog", + exception.getMessage()); + assertEquals("sql(query=\"select 1\")", sparkSqlFunctionImplementation.toString()); + assertEquals(ExprCoreType.STRUCT, sparkSqlFunctionImplementation.type()); + } + + @Test + void testApplyArguments() { + FunctionName functionName = new FunctionName("sql"); + List namedArgumentExpressionList = + List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation = + new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + SparkTable sparkTable = (SparkTable) sparkSqlFunctionImplementation.applyArguments(); + assertNotNull(sparkTable.getSparkQueryRequest()); + SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); + assertEquals(QUERY, sparkQueryRequest.getSql()); + } + + @Test + void testApplyArgumentsException() { + FunctionName functionName = new FunctionName("sql"); + List namedArgumentExpressionList = + List.of( + DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument("tmp", DSL.literal(12345))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation = + new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> sparkSqlFunctionImplementation.applyArguments()); + assertEquals("Invalid Function Argument:tmp", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java new file mode 100644 index 0000000000..212056eb15 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; + +public class SparkSqlFunctionTableScanBuilderTest { + @Mock private SparkClient sparkClient; + + @Mock private LogicalProject logicalProject; + + @Test + void testBuild() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder = + new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); + TableScanOperator sqlFunctionTableScanOperator = sparkSqlFunctionTableScanBuilder.build(); + Assertions.assertTrue( + sqlFunctionTableScanOperator instanceof SparkSqlFunctionTableScanOperator); + } + + @Test + void testPushProject() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder = + new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); + Assertions.assertTrue(sparkSqlFunctionTableScanBuilder.pushDownProject(logicalProject)); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java new file mode 100644 index 0000000000..d44e3d271a --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java @@ -0,0 +1,292 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import lombok.SneakyThrows; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprByteValue; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimestampValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.data.type.SparkDataType; +import org.opensearch.sql.spark.data.value.SparkExprValue; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlFunctionTableScanOperatorTest { + + @Mock private SparkClient sparkClient; + + @Test + @SneakyThrows + void testEmptyQueryWithException() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())).thenThrow(new IOException("Error Message")); + RuntimeException runtimeException = + assertThrows(RuntimeException.class, sparkSqlFunctionTableScanOperator::open); + assertEquals( + "Error fetching data from spark server: Error Message", runtimeException.getMessage()); + } + + @Test + @SneakyThrows + void testClose() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + sparkSqlFunctionTableScanOperator.close(); + } + + @Test + @SneakyThrows + void testExplain() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + Assertions.assertEquals("sql(select 1)", sparkSqlFunctionTableScanOperator.explain()); + } + + @Test + @SneakyThrows + void testQueryResponseIterator() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("select_query_response.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + ExprTupleValue firstRow = + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("1", new ExprIntegerValue(1)); + } + }); + assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + void testQueryResponseAllTypes() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("all_data_type.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + ExprTupleValue firstRow = + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("boolean", ExprBooleanValue.of(true)); + put("long", new ExprLongValue(922337203)); + put("integer", new ExprIntegerValue(2147483647)); + put("short", new ExprShortValue(32767)); + put("byte", new ExprByteValue(127)); + put("double", new ExprDoubleValue(9223372036854.775807)); + put("float", new ExprFloatValue(21474.83647)); + put("timestamp", new ExprDateValue("2023-07-01 10:31:30")); + put("date", new ExprTimestampValue("2023-07-01 10:31:30")); + put("string", new ExprStringValue("ABC")); + put("char", new SparkExprValue(new SparkDataType("char"), "A")); + } + }); + assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + void testQueryResponseSparkDataType() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("spark_data_type.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put( + "struct_column", + new SparkExprValue( + new SparkDataType("struct"), + new JSONObject("{\"struct_value\":\"value\"}}").toMap())); + put( + "array_column", + new SparkExprValue( + new SparkDataType("array"), new JSONArray("[1,2]").toList())); + } + }), + sparkSqlFunctionTableScanOperator.next()); + } + + @Test + @SneakyThrows + void testQuerySchema() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("select_query_response.json"))); + sparkSqlFunctionTableScanOperator.open(); + ArrayList columns = new ArrayList<>(); + columns.add(new ExecutionEngine.Schema.Column("1", "1", ExprCoreType.INTEGER)); + ExecutionEngine.Schema expectedSchema = new ExecutionEngine.Schema(columns); + assertEquals(expectedSchema, sparkSqlFunctionTableScanOperator.schema()); + } + + /** https://github.com/opensearch-project/sql/issues/2210. */ + @Test + @SneakyThrows + void issue2210() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("issue2210.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("col_name", stringValue("day")); + put("data_type", stringValue("int")); + put("comment", nullValue()); + } + }), + sparkSqlFunctionTableScanOperator.next()); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("col_name", stringValue("# Partition Information")); + put("data_type", stringValue("")); + put("comment", stringValue("")); + } + }), + sparkSqlFunctionTableScanOperator.next()); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("col_name", stringValue("# col_name")); + put("data_type", stringValue("data_type")); + put("comment", stringValue("comment")); + } + }), + sparkSqlFunctionTableScanOperator.next()); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("col_name", stringValue("day")); + put("data_type", stringValue("int")); + put("comment", nullValue()); + } + }), + sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + public void issue2367MissingFields() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn( + new JSONObject( + "{\n" + + " \"data\": {\n" + + " \"result\": [\n" + + " \"{}\",\n" + + " \"{'srcPort':20641}\"\n" + + " ],\n" + + " \"schema\": [\n" + + " \"{'column_name':'srcPort','data_type':'long'}\"\n" + + " ]\n" + + " }\n" + + "}")); + sparkSqlFunctionTableScanOperator.open(); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("srcPort", ExprNullValue.of()); + } + }), + sparkSqlFunctionTableScanOperator.next()); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("srcPort", new ExprLongValue(20641L)); + } + }), + sparkSqlFunctionTableScanOperator.next()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java new file mode 100644 index 0000000000..a828ac76c4 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; +import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.spark.storage.SparkTable; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlTableFunctionResolverTest { + @Mock private SparkClient client; + + @Mock private FunctionProperties functionProperties; + + @Test + void testResolve() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver = + new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + Pair resolution = + sqlTableFunctionResolver.resolve(functionSignature); + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + FunctionBuilder functionBuilder = resolution.getValue(); + TableFunctionImplementation functionImplementation = + (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); + assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation); + SparkTable sparkTable = (SparkTable) functionImplementation.applyArguments(); + assertNotNull(sparkTable.getSparkQueryRequest()); + SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); + assertEquals(QUERY, sparkQueryRequest.getSql()); + } + + @Test + void testArgumentsPassedByPosition() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver = + new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions = List.of(DSL.namedArgument(null, DSL.literal(QUERY))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = + sqlTableFunctionResolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + FunctionBuilder functionBuilder = resolution.getValue(); + TableFunctionImplementation functionImplementation = + (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); + assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation); + SparkTable sparkTable = (SparkTable) functionImplementation.applyArguments(); + assertNotNull(sparkTable.getSparkQueryRequest()); + SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); + assertEquals(QUERY, sparkQueryRequest.getSql()); + } + + @Test + void testMixedArgumentTypes() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver = + new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions = + List.of( + DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument(null, DSL.literal(12345))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + Pair resolution = + sqlTableFunctionResolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> resolution.getValue().apply(functionProperties, expressions)); + + assertEquals("Arguments should be either passed by name or position", exception.getMessage()); + } + + @Test + void testWrongArgumentsSizeWhenPassedByName() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver = + new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions = List.of(); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + Pair resolution = + sqlTableFunctionResolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> resolution.getValue().apply(functionProperties, expressions)); + + assertEquals("Missing arguments:[query]", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandleTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandleTest.java deleted file mode 100644 index 3467eb8781..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandleTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.functions.response; - -import static org.junit.jupiter.api.Assertions.*; - -import java.net.URL; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.List; -import java.util.Map; -import org.json.JSONObject; -import org.junit.jupiter.api.Test; -import org.opensearch.sql.data.model.ExprBooleanValue; -import org.opensearch.sql.data.model.ExprByteValue; -import org.opensearch.sql.data.model.ExprDateValue; -import org.opensearch.sql.data.model.ExprDoubleValue; -import org.opensearch.sql.data.model.ExprFloatValue; -import org.opensearch.sql.data.model.ExprIntegerValue; -import org.opensearch.sql.data.model.ExprLongValue; -import org.opensearch.sql.data.model.ExprShortValue; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.executor.ExecutionEngine; -import org.opensearch.sql.executor.ExecutionEngine.Schema.Column; - -class DefaultSparkSqlFunctionResponseHandleTest { - - @Test - public void testConstruct() throws Exception { - DefaultSparkSqlFunctionResponseHandle handle = - new DefaultSparkSqlFunctionResponseHandle(readJson()); - - assertTrue(handle.hasNext()); - ExprValue value = handle.next(); - Map row = value.tupleValue(); - assertEquals(ExprBooleanValue.of(true), row.get("col1")); - assertEquals(new ExprLongValue(2), row.get("col2")); - assertEquals(new ExprIntegerValue(3), row.get("col3")); - assertEquals(new ExprShortValue(4), row.get("col4")); - assertEquals(new ExprByteValue(5), row.get("col5")); - assertEquals(new ExprDoubleValue(6.1), row.get("col6")); - assertEquals(new ExprFloatValue(7.1), row.get("col7")); - assertEquals(new ExprStringValue("2024-01-02 03:04:05.1234"), row.get("col8")); - assertEquals(new ExprDateValue("2024-01-03 04:05:06.1234"), row.get("col9")); - assertEquals(new ExprStringValue("some string"), row.get("col10")); - - ExecutionEngine.Schema schema = handle.schema(); - List columns = schema.getColumns(); - assertEquals("col1", columns.get(0).getName()); - } - - private JSONObject readJson() throws Exception { - final URL url = - DefaultSparkSqlFunctionResponseHandle.class.getResource( - "/spark_execution_result_test.json"); - return new JSONObject(Files.readString(Paths.get(url.toURI()))); - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/helper/FlintHelperTest.java b/spark/src/test/java/org/opensearch/sql/spark/helper/FlintHelperTest.java deleted file mode 100644 index 009119a016..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/helper/FlintHelperTest.java +++ /dev/null @@ -1,45 +0,0 @@ -package org.opensearch.sql.spark.helper; - -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_REGION; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INTEGRATION_JAR; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -class FlintHelperTest { - - private static final String JAR = "JAR"; - private static final String HOST = "HOST"; - private static final String PORT = "PORT"; - private static final String SCHEME = "SCHEME"; - private static final String AUTH = "AUTH"; - private static final String REGION = "REGION"; - - @Test - public void testConstructorWithNull() { - FlintHelper helper = new FlintHelper(null, null, null, null, null, null); - - Assertions.assertEquals(FLINT_INTEGRATION_JAR, helper.getFlintIntegrationJar()); - Assertions.assertEquals(FLINT_DEFAULT_HOST, helper.getFlintHost()); - Assertions.assertEquals(FLINT_DEFAULT_PORT, helper.getFlintPort()); - Assertions.assertEquals(FLINT_DEFAULT_SCHEME, helper.getFlintScheme()); - Assertions.assertEquals(FLINT_DEFAULT_AUTH, helper.getFlintAuth()); - Assertions.assertEquals(FLINT_DEFAULT_REGION, helper.getFlintRegion()); - } - - @Test - public void testConstructor() { - FlintHelper helper = new FlintHelper(JAR, HOST, PORT, SCHEME, AUTH, REGION); - - Assertions.assertEquals(JAR, helper.getFlintIntegrationJar()); - Assertions.assertEquals(HOST, helper.getFlintHost()); - Assertions.assertEquals(PORT, helper.getFlintPort()); - Assertions.assertEquals(SCHEME, helper.getFlintScheme()); - Assertions.assertEquals(AUTH, helper.getFlintAuth()); - Assertions.assertEquals(REGION, helper.getFlintRegion()); - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java new file mode 100644 index 0000000000..971db3c33c --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.client.SparkClient; + +@ExtendWith(MockitoExtension.class) +public class SparkScanTest { + @Mock private SparkClient sparkClient; + + @Test + @SneakyThrows + void testQueryResponseIteratorForQueryRangeFunction() { + SparkScan sparkScan = new SparkScan(sparkClient); + sparkScan.getRequest().setSql(QUERY); + Assertions.assertFalse(sparkScan.hasNext()); + assertNull(sparkScan.next()); + } + + @Test + @SneakyThrows + void testExplain() { + SparkScan sparkScan = new SparkScan(sparkClient); + sparkScan.getRequest().setSql(QUERY); + assertEquals("SparkQueryRequest(sql=select 1)", sparkScan.explain()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java new file mode 100644 index 0000000000..5e7ec76cdb --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collection; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver; + +@ExtendWith(MockitoExtension.class) +public class SparkStorageEngineTest { + @Mock private SparkClient client; + + @Test + public void getFunctions() { + SparkStorageEngine engine = new SparkStorageEngine(client); + Collection functionResolverCollection = engine.getFunctions(); + assertNotNull(functionResolverCollection); + assertEquals(1, functionResolverCollection.size()); + assertTrue( + functionResolverCollection.iterator().next() instanceof SparkSqlTableFunctionResolver); + } + + @Test + public void getTable() { + SparkStorageEngine engine = new SparkStorageEngine(client); + RuntimeException exception = + assertThrows( + RuntimeException.class, + () -> engine.getTable(new DataSourceSchemaName("spark", "default"), "")); + assertEquals("Unable to get table from storage engine.", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java new file mode 100644 index 0000000000..ebe3c8f3a9 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; + +import java.security.InvalidParameterException; +import java.util.HashMap; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.Client; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.storage.StorageEngine; + +@ExtendWith(MockitoExtension.class) +public class SparkStorageFactoryTest { + @Mock private Settings settings; + + @Mock private Client client; + + @Test + void testGetConnectorType() { + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + Assertions.assertEquals(DataSourceType.SPARK, sparkStorageFactory.getDataSourceType()); + } + + @Test + @SneakyThrows + void testGetStorageEngine() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "access_key"); + properties.put("emr.auth.secret_key", "secret_key"); + properties.put("emr.auth.region", "region"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + StorageEngine storageEngine = sparkStorageFactory.getStorageEngine(properties); + Assertions.assertTrue(storageEngine instanceof SparkStorageEngine); + } + + @Test + @SneakyThrows + void testInvalidConnectorType() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "random"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + InvalidParameterException exception = + Assertions.assertThrows( + InvalidParameterException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("Spark connector type is invalid.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingAuth() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR config properties are missing.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testUnsupportedEmrAuth() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "basic"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("Invalid auth type.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingCluster() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.auth.type", "awssigv4"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR config properties are missing.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingAuthKeys() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR auth keys are missing.", exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingAuthSecretKey() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "test"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR auth keys are missing.", exception.getMessage()); + } + + @Test + void testCreateDataSourceSuccess() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "access_key"); + properties.put("emr.auth.secret_key", "secret_key"); + properties.put("emr.auth.region", "region"); + properties.put("spark.datasource.flint.host", "localhost"); + properties.put("spark.datasource.flint.port", "9200"); + properties.put("spark.datasource.flint.scheme", "http"); + properties.put("spark.datasource.flint.auth", "false"); + properties.put("spark.datasource.flint.region", "us-west-2"); + + DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setName("spark") + .setConnector(DataSourceType.SPARK) + .setProperties(properties) + .build(); + + DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata); + Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine); + } + + @Test + void testSetSparkJars() { + HashMap properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("spark.sql.application", "s3://spark/spark-sql-job.jar"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "access_key"); + properties.put("emr.auth.secret_key", "secret_key"); + properties.put("emr.auth.region", "region"); + properties.put("spark.datasource.flint.integration", "s3://spark/flint-spark-integration.jar"); + + DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setName("spark") + .setConnector(DataSourceType.SPARK) + .setProperties(properties) + .build(); + + DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata); + Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java new file mode 100644 index 0000000000..a70d4ba69e --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.read.TableScanBuilder; + +@ExtendWith(MockitoExtension.class) +public class SparkTableTest { + @Mock private SparkClient client; + + @Test + void testUnsupportedOperation() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + SparkTable sparkTable = new SparkTable(client, sparkQueryRequest); + + assertThrows(UnsupportedOperationException.class, sparkTable::exists); + assertThrows( + UnsupportedOperationException.class, () -> sparkTable.create(Collections.emptyMap())); + } + + @Test + void testCreateScanBuilderWithSqlTableFunction() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + SparkTable sparkTable = new SparkTable(client, sparkQueryRequest); + TableScanBuilder tableScanBuilder = sparkTable.createScanBuilder(); + Assertions.assertNotNull(tableScanBuilder); + Assertions.assertTrue(tableScanBuilder instanceof SparkSqlFunctionTableScanBuilder); + } + + @Test + @SneakyThrows + void testGetFieldTypesFromSparkQueryRequest() { + SparkTable sparkTable = new SparkTable(client, new SparkQueryRequest()); + Map expectedFieldTypes = new HashMap<>(); + Map fieldTypes = sparkTable.getFieldTypes(); + + assertEquals(expectedFieldTypes, fieldTypes); + verifyNoMoreInteractions(client); + assertNotNull(sparkTable.getSparkQueryRequest()); + } + + @Test + void testImplementWithSqlFunction() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + SparkTable sparkMetricTable = new SparkTable(client, sparkQueryRequest); + PhysicalPlan plan = + sparkMetricTable.implement(new SparkSqlFunctionTableScanBuilder(client, sparkQueryRequest)); + assertTrue(plan instanceof SparkSqlFunctionTableScanOperator); + } +} diff --git a/spark/src/test/resources/all_data_type.json b/spark/src/test/resources/all_data_type.json new file mode 100644 index 0000000000..a046912319 --- /dev/null +++ b/spark/src/test/resources/all_data_type.json @@ -0,0 +1,22 @@ +{ + "data": { + "result": [ + "{'boolean':true,'long':922337203,'integer':2147483647,'short':32767,'byte':127,'double':9223372036854.775807,'float':21474.83647,'timestamp':'2023-07-01 10:31:30','date':'2023-07-01 10:31:30','string':'ABC','char':'A'}" + ], + "schema": [ + "{'column_name':'boolean','data_type':'boolean'}", + "{'column_name':'long','data_type':'long'}", + "{'column_name':'integer','data_type':'integer'}", + "{'column_name':'short','data_type':'short'}", + "{'column_name':'byte','data_type':'byte'}", + "{'column_name':'double','data_type':'double'}", + "{'column_name':'float','data_type':'float'}", + "{'column_name':'timestamp','data_type':'timestamp'}", + "{'column_name':'date','data_type':'date'}", + "{'column_name':'string','data_type':'string'}", + "{'column_name':'char','data_type':'char'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/spark/src/test/resources/issue2210.json b/spark/src/test/resources/issue2210.json new file mode 100644 index 0000000000..dec24efdc2 --- /dev/null +++ b/spark/src/test/resources/issue2210.json @@ -0,0 +1,17 @@ +{ + "data": { + "result": [ + "{'col_name':'day','data_type':'int'}", + "{'col_name':'# Partition Information','data_type':'','comment':''}", + "{'col_name':'# col_name','data_type':'data_type','comment':'comment'}", + "{'col_name':'day','data_type':'int'}" + ], + "schema": [ + "{'column_name':'col_name','data_type':'string'}", + "{'column_name':'data_type','data_type':'string'}", + "{'column_name':'comment','data_type':'string'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/spark/src/test/resources/spark_data_type.json b/spark/src/test/resources/spark_data_type.json new file mode 100644 index 0000000000..79bd047f27 --- /dev/null +++ b/spark/src/test/resources/spark_data_type.json @@ -0,0 +1,13 @@ +{ + "data": { + "result": [ + "{'struct_column':{'struct_value':'value'},'array_column':[1,2]}" + ], + "schema": [ + "{'column_name':'struct_column','data_type':'struct'}", + "{'column_name':'array_column','data_type':'array'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/spark/src/test/resources/spark_execution_result_test.json b/spark/src/test/resources/spark_execution_result_test.json deleted file mode 100644 index 80d5a49283..0000000000 --- a/spark/src/test/resources/spark_execution_result_test.json +++ /dev/null @@ -1,79 +0,0 @@ -{ - "data" : { - "schema": [ - { - "column_name": "col1", - "data_type": "boolean" - }, - { - "column_name": "col2", - "data_type": "long" - }, - { - "column_name": "col3", - "data_type": "integer" - }, - { - "column_name": "col4", - "data_type": "short" - }, - { - "column_name": "col5", - "data_type": "byte" - }, - { - "column_name": "col6", - "data_type": "double" - }, - { - "column_name": "col7", - "data_type": "float" - }, - { - "column_name": "col8", - "data_type": "timestamp" - }, - { - "column_name": "col9", - "data_type": "date" - }, - { - "column_name": "col10", - "data_type": "string" - }, - { - "column_name": "col11", - "data_type": "other" - }, - { - "column_name": "col12", - "data_type": "other object" - }, - { - "column_name": "col13", - "data_type": "other array" - }, - { - "column_name": "col14", - "data_type": "other" - } - ], - "result": [ - { - "col1": true, - "col2": 2, - "col3": 3, - "col4": 4, - "col5": 5, - "col6": 6.1, - "col7": 7.1, - "col8": "2024-01-02 03:04:05.1234", - "col9": "2024-01-03 04:05:06.1234", - "col10": "some string", - "col11": "other value", - "col12": { "hello": "world" }, - "col13": [1, 2, 3] - } - ] - } -} \ No newline at end of file