diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index 710f472acb..c4b5c89540 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -54,7 +54,9 @@ public DispatchQueryResponse dispatch(
           dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
     }
 
-    List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
+    List<String> validationErrors =
+        SQLQueryUtils.validateSparkSqlQuery(
+            dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query);
     if (!validationErrors.isEmpty()) {
       throw new IllegalArgumentException(
           "Query is not allowed: " + String.join(", ", validationErrors));
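Note: the dispatcher now resolves the DataSource before validation, so datasource-specific rules can be applied per query. A minimal caller-side sketch of the contract (the class and method names below are illustrative; only `SQLQueryUtils.validateSparkSqlQuery` and the exception message format come from the diff):

    import java.util.List;
    import org.opensearch.sql.datasource.model.DataSource;
    import org.opensearch.sql.spark.utils.SQLQueryUtils;

    class ValidationContractSketch {
      // Mirrors SparkQueryDispatcher: validate against the resolved datasource
      // and surface all accumulated errors as one IllegalArgumentException.
      static void rejectIfNotAllowed(DataSource dataSource, String query) {
        List<String> errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, query);
        if (!errors.isEmpty()) {
          throw new IllegalArgumentException("Query is not allowed: " + String.join(", ", errors));
        }
      }
    }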
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
index ff08a8f41e..ce3bcab06b 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
@@ -15,9 +15,13 @@
 import org.antlr.v4.runtime.CommonTokenStream;
 import org.antlr.v4.runtime.misc.Interval;
 import org.antlr.v4.runtime.tree.ParseTree;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
 import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
 import org.opensearch.sql.common.antlr.SyntaxCheckException;
+import org.opensearch.sql.datasource.model.DataSource;
+import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
@@ -25,6 +29,7 @@
 import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.StatementContext;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
 import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
@@ -38,13 +43,14 @@
  */
 @UtilityClass
 public class SQLQueryUtils {
+  private static final Logger logger = LogManager.getLogger(SQLQueryUtils.class);
 
   public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
     SqlBaseParser sqlBaseParser =
         new SqlBaseParser(
             new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
     sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
-    SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
+    StatementContext statement = sqlBaseParser.statement();
     SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
     statement.accept(sparkSqlTableNameVisitor);
     return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
@@ -77,32 +83,73 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
     }
   }
 
-  public static List<String> validateSparkSqlQuery(String sqlQuery) {
-    SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor();
+  public static List<String> validateSparkSqlQuery(DataSource datasource, String sqlQuery) {
     SqlBaseParser sqlBaseParser =
         new SqlBaseParser(
             new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
     sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
     try {
-      SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
-      sparkSqlValidatorVisitor.visit(statement);
-      return sparkSqlValidatorVisitor.getValidationErrors();
-    } catch (SyntaxCheckException syntaxCheckException) {
+      SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource);
+      StatementContext statement = sqlBaseParser.statement();
+      sqlParserBaseVisitor.visit(statement);
+      return sqlParserBaseVisitor.getValidationErrors();
+    } catch (SyntaxCheckException e) {
+      logger.error(
+          String.format(
+              "Failed to parse sql statement context while validating sql query %s", sqlQuery),
+          e);
       return Collections.emptyList();
     }
   }
 
-  private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {
+  private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) {
+    if (datasource != null
+        && datasource.getConnectorType() != null
+        && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) {
+      return new SparkSqlSecurityLakeValidatorVisitor();
+    } else {
+      return new SparkSqlValidatorVisitor();
+    }
+  }
 
-    @Getter private final List<String> validationErrors = new ArrayList<>();
+  /**
+   * A base class extending SqlBaseParserBaseVisitor for validating Spark Sql queries. It
+   * accumulates validation errors while visiting the sql statement.
+   */
+  @Getter
+  private static class SqlBaseValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {
+    private final List<String> validationErrors = new ArrayList<>();
+  }
 
+  /** A generic validator impl for Spark Sql queries. */
+  private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor {
     @Override
     public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
-      validationErrors.add("Creating user-defined functions is not allowed");
+      getValidationErrors().add("Creating user-defined functions is not allowed");
      return super.visitCreateFunction(ctx);
     }
   }
 
+  /** A validator impl specific to Security Lake for Spark Sql queries. */
+  private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor {
+
+    public SparkSqlSecurityLakeValidatorVisitor() {
+      // Only select statements are allowed, so the validation error is added for every
+      // statement type by default and cleared only when a select statement is visited.
+      getValidationErrors()
+          .add(
+              "Unsupported sql statement for security lake data source. Only select queries are"
+                  + " allowed");
+    }
+
+    @Override
+    public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
+      getValidationErrors().clear();
+      return super.visitStatementDefault(ctx);
+    }
+  }
+
   public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
 
     @Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
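Note: the Security Lake visitor uses a deny-by-default pattern. A parser-free sketch of the same idea (the class and method names below are hypothetical; only the error string and the seed-then-clear behavior come from the diff):

    import java.util.ArrayList;
    import java.util.List;

    class DenyByDefaultSketch {
      private final List<String> validationErrors = new ArrayList<>();

      DenyByDefaultSketch() {
        // Seed the error up front: every statement type is rejected unless...
        validationErrors.add(
            "Unsupported sql statement for security lake data source. Only select queries are"
                + " allowed");
      }

      // ...the allowed node fires. Stand-in for visitStatementDefault, the
      // statement alternative that plain SELECT queries parse into.
      void onSelectStatement() {
        validationErrors.clear();
      }

      List<String> errors() {
        return validationErrors;
      }
    }

Because only a top-level statementDefault node clears the pre-seeded error, a query that merely contains a SELECT clause inside another statement (for example CREATE TABLE ... AS SELECT) still fails: the nested SELECT is a query fragment, not a top-level statement node. The tests below pin exactly that behavior.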
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
index fe7777606c..235fe84c70 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
@@ -10,6 +10,7 @@
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.when;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex;
@@ -18,7 +19,10 @@
 import lombok.Getter;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.datasource.model.DataSource;
+import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
 import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
 import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
@@ -27,6 +31,8 @@
 @ExtendWith(MockitoExtension.class)
 public class SQLQueryUtilsTest {
 
+  @Mock private DataSource dataSource;
+
   @Test
   void testExtractionOfTableNameFromSQLQueries() {
     String sqlQuery = "select * from my_glue.default.http_logs";
@@ -404,15 +410,96 @@ void testAutoRefresh() {
 
   @Test
   void testValidateSparkSqlQuery_ValidQuery() {
-    String validQuery = "SELECT * FROM users WHERE age > 18";
-    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(validQuery);
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'",
+            DataSourceType.PROMETHEUS);
+
     assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
   }
 
+  @Test
+  void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE);
+
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null);
+
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE);
+
+    assertTrue(errors.isEmpty(), "Syntax check failure should not produce validation errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_nullDatasource() {
+    List<String> errors =
+        SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18");
+
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  private List<String> validateSparkSqlQueryForDataSourceType(
+      String query, DataSourceType dataSourceType) {
+    when(this.dataSource.getConnectorType()).thenReturn(dataSourceType);
+
+    return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query);
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE);
+
+    assertFalse(
+        errors.isEmpty(),
+        "Invalid query as Security Lake datasource supports only flint queries and SELECT sql"
+            + " queries. Given query was a REFRESH sql query");
+    assertEquals(
+        errors.get(0),
+        "Unsupported sql statement for security lake data source. Only select queries are allowed");
+  }
+
+  @Test
+  void
+      testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() {
+    String query =
+        "CREATE TABLE AccountSummaryOrWhatever AS "
+            + "select taxid, address1, count(address1) from dbo.t "
+            + "group by taxid, address1;";
+
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE);
+
+    assertFalse(
+        errors.isEmpty(),
+        "Invalid query as Security Lake datasource supports only flint queries and SELECT sql"
+            + " queries. Given query was a CTAS sql query");
+    assertEquals(
+        errors.get(0),
+        "Unsupported sql statement for security lake data source. Only select queries are allowed");
+  }
+
   @Test
   void testValidateSparkSqlQuery_InvalidQuery() {
+    when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS);
     String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'";
-    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(invalidQuery);
+
+    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery);
+
     assertFalse(errors.isEmpty(), "Invalid query should produce errors");
     assertEquals(1, errors.size(), "Should have one error");
     assertEquals(
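Note: the SyntaxCheckFailureSkipped test above pins a deliberate choice: validation never rejects a query for bad syntax, since validateSparkSqlQuery swallows SyntaxCheckException and returns an empty list, presumably leaving the syntax error to surface when Spark parses the query itself. A hedged, standalone sketch of that contract (the class name is hypothetical; the mock setup assumes Mockito as in the tests above):

    import static org.junit.jupiter.api.Assertions.assertTrue;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    import java.util.List;
    import org.opensearch.sql.datasource.model.DataSource;
    import org.opensearch.sql.datasource.model.DataSourceType;
    import org.opensearch.sql.spark.utils.SQLQueryUtils;

    class SyntaxFailureContractSketch {
      void demo() {
        DataSource dataSource = mock(DataSource.class);
        when(dataSource.getConnectorType()).thenReturn(DataSourceType.SECURITY_LAKE);

        // The parse fails, the exception is caught and logged inside
        // validateSparkSqlQuery, and no validation error is reported here.
        List<String> errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, "SEECT * FROM users");
        assertTrue(errors.isEmpty());
      }
    }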
diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java
index c74964fc00..ac8ae1a5e1 100644
--- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java
+++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java
@@ -7,9 +7,11 @@
 
 import java.util.HashMap;
 import java.util.Map;
+import lombok.EqualsAndHashCode;
 import lombok.RequiredArgsConstructor;
 
 @RequiredArgsConstructor
+@EqualsAndHashCode
 public class DataSourceType {
   public static DataSourceType PROMETHEUS = new DataSourceType("PROMETHEUS");
   public static DataSourceType OPENSEARCH = new DataSourceType("OPENSEARCH");
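Note: @EqualsAndHashCode matters here because DataSourceType is a plain class rather than an enum, so an instance created outside the static registry would otherwise fail the equals() check in getSparkSqlValidatorVisitor, which previously fell back to Object's identity comparison. A small sketch (the class name is illustrative; constructing DataSourceType directly assumes the public @RequiredArgsConstructor shown above, and DataSourceType.SECURITY_LAKE is the static referenced by the validator):

    import org.opensearch.sql.datasource.model.DataSourceType;

    class DataSourceTypeEqualitySketch {
      public static void main(String[] args) {
        // Before this change equals() was identity-based and this printed false;
        // with Lombok value equality over the type name it prints true.
        DataSourceType fresh = new DataSourceType("SECURITY_LAKE");
        System.out.println(fresh.equals(DataSourceType.SECURITY_LAKE));
      }
    }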