diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java index f54ac49b4e..10fc48727a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -30,6 +30,8 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ErrorCapturingIdentifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ErrorCapturingIdentifierExtraContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; @@ -43,6 +45,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LiteralTypeContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext; @@ -77,7 +80,9 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TypeContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsupportedHiveNativeCommandsContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; /** This visitor validate grammar using GrammarElementValidator */ @@ -584,4 +589,34 @@ private void validateAllowed(SQLGrammarElement element) { throw new IllegalArgumentException(element + " is not allowed."); } } + + @Override + public Void visitErrorCapturingIdentifier(ErrorCapturingIdentifierContext ctx) { + ErrorCapturingIdentifierExtraContext extra = ctx.errorCapturingIdentifierExtra(); + if (extra.children != null) { + throw new IllegalArgumentException("Invalid identifier: " + ctx.getText()); + } + return super.visitErrorCapturingIdentifier(ctx); + } + + @Override + public Void visitLiteralType(LiteralTypeContext ctx) { + if (ctx.unsupportedType != null) { + throw new IllegalArgumentException("Unsupported typed literal: " + ctx.getText()); + } + return super.visitLiteralType(ctx); + } + + @Override + public Void visitType(TypeContext ctx) { + if (ctx.unsupportedType != null) { + throw new IllegalArgumentException("Unsupported data type: " + ctx.getText()); + } + return super.visitType(ctx); + } + + @Override + public Void visitUnsupportedHiveNativeCommands(UnsupportedHiveNativeCommandsContext ctx) { + throw new IllegalArgumentException("Unsupported command."); + } } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 3e4eef52fd..ad73daa37f 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -571,6 +571,56 @@ void testValidateFlintExtensionQuery() { UUID.randomUUID().toString(), DataSourceType.SECURITY_LAKE)); } + @Test + void testInvalidIdentifier() { + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + v.ng("SELECT a.b.c as a-b-c FROM abc"); + v.ok("SELECT a.b.c as `a-b-c` FROM abc"); + v.ok("SELECT a.b.c as a_b_c FROM abc"); + + v.ng("SELECT a.b.c FROM a-b-c"); + v.ng("SELECT a.b.c FROM a.b-c"); + v.ok("SELECT a.b.c FROM b.c.`a-b-c`"); + v.ok("SELECT a.b.c FROM `a-b-c`"); + } + + @Test + void testUnsupportedType() { + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + + v.ng("SELECT cast ( a as DateTime ) FROM tbl"); + v.ok("SELECT cast ( a as DATE ) FROM tbl"); + v.ok("SELECT cast ( a as Date ) FROM tbl"); + v.ok("SELECT cast ( a as Timestamp ) FROM tbl"); + } + + @Test + void testUnsupportedTypedLiteral() { + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + + v.ng("SELECT DATETIME '2024-10-11'"); + v.ok("SELECT DATE '2024-10-11'"); + v.ok("SELECT TIMESTAMP '2024-10-11'"); + } + + @Test + void testUnsupportedHiveNativeCommand() { + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + + v.ng("CREATE ROLE aaa"); + v.ng("SHOW GRANT"); + v.ng("EXPORT TABLE"); + v.ng("ALTER TABLE aaa NOT CLUSTERED"); + v.ng("START TRANSACTION"); + v.ng("COMMIT"); + v.ng("ROLLBACK"); + v.ng("DFS"); + } + @AllArgsConstructor private static class VerifyValidator { private final SQLQueryValidator validator; @@ -580,10 +630,18 @@ public void ok(TestElement query) { runValidate(query.getQueries()); } + public void ok(String query) { + runValidate(query); + } + public void ng(TestElement query) { + Arrays.stream(query.getQueries()).forEach(this::ng); + } + + public void ng(String query) { assertThrows( IllegalArgumentException.class, - () -> runValidate(query.getQueries()), + () -> runValidate(query), "The query should throw: query=`" + query.toString() + "`"); }