diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index f59afe8180..ca5b7bf6e4 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -44,7 +44,7 @@ Sample Request:: curl --location 'http://localhost:9200/_plugins/_async_query' \ --header 'Content-Type: application/json' \ --data '{ - "kind" : "sql", + "lang" : "sql", "query" : "select * from my_glue.default.http_logs limit 10" }' diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index ed10b1e3e6..adb0d21098 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -301,7 +301,10 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService() { JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - sparkJobClient, this.dataSourceService, jobExecutionResponseReader); + sparkJobClient, + this.dataSourceService, + new DataSourceUserAuthorizationHelperImpl(client), + jobExecutionResponseReader); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, pluginSettings); } diff --git a/spark/build.gradle b/spark/build.gradle index fb9a1e0e4b..2bee7408a5 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -7,13 +7,42 @@ plugins { id 'java-library' id "io.freefair.lombok" id 'jacoco' + id 'antlr' } repositories { mavenCentral() } +tasks.register('downloadG4Files', Exec) { + description = 'Download remote .g4 files from GitHub' + + executable 'curl' + +// Need to add these back once the grammar issues with indexName and tableName is addressed in flint integration jar. +// args '-o', 'src/main/antlr/FlintSparkSqlExtensions.g4', 'https://raw.githubusercontent.com/opensearch-project/opensearch-spark/main/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4' +// args '-o', 'src/main/antlr/SparkSqlBase.g4', 'https://raw.githubusercontent.com/opensearch-project/opensearch-spark/main/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4' + args '-o', 'src/main/antlr/SqlBaseParser.g4', 'https://raw.githubusercontent.com/apache/spark/master/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4' + args '-o', 'src/main/antlr/SqlBaseLexer.g4', 'https://raw.githubusercontent.com/apache/spark/master/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4' +} + +generateGrammarSource { + arguments += ['-visitor', '-package', 'org.opensearch.sql.spark.antlr.parser'] + source = sourceSets.main.antlr + outputDirectory = file("build/generated-src/antlr/main/org/opensearch/sql/spark/antlr/parser") +} +configurations { + compile { + extendsFrom = extendsFrom.findAll { it != configurations.antlr } + } +} + +// Make sure the downloadG4File task runs before the generateGrammarSource task +generateGrammarSource.dependsOn downloadG4Files + dependencies { + antlr "org.antlr:antlr4:4.7.1" + api project(':core') implementation project(':protocol') implementation project(':datasources') @@ -46,7 +75,7 @@ jacocoTestReport { } afterEvaluate { classDirectories.setFrom(files(classDirectories.files.collect { - fileTree(dir: it) + fileTree(dir: it, exclude: ['**/antlr/parser/**']) })) } } @@ -61,7 +90,8 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.rest.*', 'org.opensearch.sql.spark.transport.model.*', 'org.opensearch.sql.spark.asyncquery.model.*', - 'org.opensearch.sql.spark.asyncquery.exceptions.*' + 'org.opensearch.sql.spark.asyncquery.exceptions.*', + 'org.opensearch.sql.spark.dispatcher.model.*' ] limit { counter = 'LINE' @@ -75,7 +105,7 @@ jacocoTestCoverageVerification { } afterEvaluate { classDirectories.setFrom(files(classDirectories.files.collect { - fileTree(dir: it) + fileTree(dir: it, exclude: ['**/antlr/parser/**']) })) } } diff --git a/spark/src/main/antlr/FlintSparkSqlExtensions.g4 b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 new file mode 100644 index 0000000000..2d50fbc49f --- /dev/null +++ b/spark/src/main/antlr/FlintSparkSqlExtensions.g4 @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +grammar FlintSparkSqlExtensions; + +import SparkSqlBase; + + +// Flint SQL Syntax Extension + +singleStatement + : statement SEMICOLON* EOF + ; + +statement + : skippingIndexStatement + | coveringIndexStatement + ; + +skippingIndexStatement + : createSkippingIndexStatement + | refreshSkippingIndexStatement + | describeSkippingIndexStatement + | dropSkippingIndexStatement + ; + +createSkippingIndexStatement + : CREATE SKIPPING INDEX ON tableName + LEFT_PAREN indexColTypeList RIGHT_PAREN + (WITH LEFT_PAREN propertyList RIGHT_PAREN)? + ; + +refreshSkippingIndexStatement + : REFRESH SKIPPING INDEX ON tableName + ; + +describeSkippingIndexStatement + : (DESC | DESCRIBE) SKIPPING INDEX ON tableName + ; + +dropSkippingIndexStatement + : DROP SKIPPING INDEX ON tableName + ; + +coveringIndexStatement + : createCoveringIndexStatement + | refreshCoveringIndexStatement + | showCoveringIndexStatement + | describeCoveringIndexStatement + | dropCoveringIndexStatement + ; + +createCoveringIndexStatement + : CREATE INDEX indexName ON tableName + LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN + (WITH LEFT_PAREN propertyList RIGHT_PAREN)? + ; + +refreshCoveringIndexStatement + : REFRESH INDEX indexName ON tableName + ; + +showCoveringIndexStatement + : SHOW (INDEX | INDEXES) ON tableName + ; + +describeCoveringIndexStatement + : (DESC | DESCRIBE) INDEX indexName ON tableName + ; + +dropCoveringIndexStatement + : DROP INDEX indexName ON tableName + ; + +indexColTypeList + : indexColType (COMMA indexColType)* + ; + +indexColType + : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX) + ; + +indexName + : identifier + ; + +tableName + : multipartIdentifier + ; diff --git a/spark/src/main/antlr/SparkSqlBase.g4 b/spark/src/main/antlr/SparkSqlBase.g4 new file mode 100644 index 0000000000..928f63812c --- /dev/null +++ b/spark/src/main/antlr/SparkSqlBase.g4 @@ -0,0 +1,223 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar SparkSqlBase; + +// Copy from Spark 3.3.1 SqlBaseParser.g4 and SqlBaseLexer.g4 + +@members { + /** + * When true, parser should throw ParseExcetion for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. + */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } +} + + +multipartIdentifierPropertyList + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* + ; + +multipartIdentifierProperty + : multipartIdentifier (options=propertyList)? + ; + +propertyList + : property (COMMA property)* + ; + +property + : key=propertyKey (EQ? value=propertyValue)? + ; + +propertyKey + : identifier (DOT identifier)* + | STRING + ; + +propertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +booleanValue + : TRUE | FALSE + ; + + +multipartIdentifier + : parts+=identifier (DOT parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +nonReserved + : DROP | SKIPPING | INDEX + ; + + +// Flint lexical tokens + +MIN_MAX: 'MIN_MAX'; +SKIPPING: 'SKIPPING'; +VALUE_SET: 'VALUE_SET'; + + +// Spark lexical tokens + +SEMICOLON: ';'; + +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; + + +CREATE: 'CREATE'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DROP: 'DROP'; +FALSE: 'FALSE'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +ON: 'ON'; +PARTITION: 'PARTITION'; +REFRESH: 'REFRESH'; +SHOW: 'SHOW'; +STRING: 'STRING'; +TRUE: 'TRUE'; +WITH: 'WITH'; + + +EQ : '=' | '=='; +MINUS: '-'; + + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; \ No newline at end of file diff --git a/spark/src/main/antlr/SqlBaseLexer.g4 b/spark/src/main/antlr/SqlBaseLexer.g4 new file mode 100644 index 0000000000..d9128de0f5 --- /dev/null +++ b/spark/src/main/antlr/SqlBaseLexer.g4 @@ -0,0 +1,551 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +lexer grammar SqlBaseLexer; + +@members { + /** + * When true, parser should throw ParseException for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. + */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } +} + +SEMICOLON: ';'; + +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; +LEFT_BRACKET: '['; +RIGHT_BRACKET: ']'; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`, and +// modify `ParserUtils.toExprAlias()` which assumes all keywords are between `ADD` and `ZONE`. + +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ALWAYS: 'ALWAYS'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ANY_VALUE: 'ANY_VALUE'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BIGINT: 'BIGINT'; +BINARY: 'BINARY'; +BOOLEAN: 'BOOLEAN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +BYTE: 'BYTE'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CATALOG: 'CATALOG'; +CATALOGS: 'CATALOGS'; +CHANGE: 'CHANGE'; +CHAR: 'CHAR'; +CHARACTER: 'CHARACTER'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DAY: 'DAY'; +DAYS: 'DAYS'; +DAYOFYEAR: 'DAYOFYEAR'; +DATA: 'DATA'; +DATE: 'DATE'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES'; +DATEADD: 'DATEADD'; +DATE_ADD: 'DATE_ADD'; +DATEDIFF: 'DATEDIFF'; +DATE_DIFF: 'DATE_DIFF'; +DBPROPERTIES: 'DBPROPERTIES'; +DEC: 'DEC'; +DECIMAL: 'DECIMAL'; +DECLARE: 'DECLARE'; +DEFAULT: 'DEFAULT'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DOUBLE: 'DOUBLE'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXCLUDE: 'EXCLUDE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FLOAT: 'FLOAT'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GENERATED: 'GENERATED'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +BINARY_HEX: 'X'; +HOUR: 'HOUR'; +HOURS: 'HOURS'; +IDENTIFIER_KW: 'IDENTIFIER'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INCLUDE: 'INCLUDE'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INT: 'INT'; +INTEGER: 'INTEGER'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +ILIKE: 'ILIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +LONG: 'LONG'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MICROSECOND: 'MICROSECOND'; +MICROSECONDS: 'MICROSECONDS'; +MILLISECOND: 'MILLISECOND'; +MILLISECONDS: 'MILLISECONDS'; +MINUTE: 'MINUTE'; +MINUTES: 'MINUTES'; +MONTH: 'MONTH'; +MONTHS: 'MONTHS'; +MSCK: 'MSCK'; +NAME: 'NAME'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NANOSECOND: 'NANOSECOND'; +NANOSECONDS: 'NANOSECONDS'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +NUMERIC: 'NUMERIC'; +OF: 'OF'; +OFFSET: 'OFFSET'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; +PARTITIONS: 'PARTITIONS'; +PERCENTILE_CONT: 'PERCENTILE_CONT'; +PERCENTILE_DISC: 'PERCENTILE_DISC'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUARTER: 'QUARTER'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +REAL: 'REAL'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPEATABLE: 'REPEATABLE'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESPECT: 'RESPECT'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SECOND: 'SECOND'; +SECONDS: 'SECONDS'; +SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHORT: 'SHORT'; +SHOW: 'SHOW'; +SINGLE: 'SINGLE'; +SKEWED: 'SKEWED'; +SMALLINT: 'SMALLINT'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +SOURCE: 'SOURCE'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRING: 'STRING'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; +SYSTEM_TIME: 'SYSTEM_TIME'; +SYSTEM_VERSION: 'SYSTEM_VERSION'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TARGET: 'TARGET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TIMEDIFF: 'TIMEDIFF'; +TIMESTAMP: 'TIMESTAMP'; +TIMESTAMP_LTZ: 'TIMESTAMP_LTZ'; +TIMESTAMP_NTZ: 'TIMESTAMP_NTZ'; +TIMESTAMPADD: 'TIMESTAMPADD'; +TIMESTAMPDIFF: 'TIMESTAMPDIFF'; +TINYINT: 'TINYINT'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNPIVOT: 'UNPIVOT'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VARCHAR: 'VARCHAR'; +VAR: 'VAR'; +VARIABLE: 'VARIABLE'; +VERSION: 'VERSION'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +VOID: 'VOID'; +WEEK: 'WEEK'; +WEEKS: 'WEEKS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +WITHIN: 'WITHIN'; +YEAR: 'YEAR'; +YEARS: 'YEARS'; +ZONE: 'ZONE'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; +COLON: ':'; +ARROW: '->'; +FAT_ARROW : '=>'; +HENT_START: '/*+'; +HENT_END: '*/'; +QUESTION: '?'; + +STRING_LITERAL + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | 'R\'' (~'\'')* '\'' + | 'R"'(~'"')* '"' + ; + +DOUBLEQUOTED_STRING + :'"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +// NOTE: If you move a numeric literal, you should modify `ParserUtils.toExprAlias()` +// which assumes all numeric literals are between `BIGINT_LITERAL` and `BIGDECIMAL_LITERAL`. + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 new file mode 100644 index 0000000000..6a6d39e96c --- /dev/null +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -0,0 +1,1905 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +parser grammar SqlBaseParser; + +options { tokenVocab = SqlBaseLexer; } + +@members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enabled = false; + + /** + * When false, a literal with an exponent would be converted into + * double type rather than decimal type. + */ + public boolean legacy_exponent_literal_as_decimal_enabled = false; + + /** + * When true, the behavior of keywords follows ANSI SQL standard. + */ + public boolean SQL_standard_keyword_behavior = false; + + /** + * When true, double quoted literals are identifiers rather than STRINGs. + */ + public boolean double_quoted_identifiers = false; +} + +singleStatement + : statement SEMICOLON* EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleMultipartIdentifier + : multipartIdentifier EOF + ; + +singleFunctionIdentifier + : functionIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +singleTableSchema + : colTypeList EOF + ; + +statement + : query #statementDefault + | ctes? dmlStatementNoWith #dmlStatement + | USE identifierReference #use + | USE namespace identifierReference #useNamespace + | SET CATALOG (identifier | stringLit) #setCatalog + | CREATE namespace (IF NOT EXISTS)? identifierReference + (commentSpec | + locationSpec | + (WITH (DBPROPERTIES | PROPERTIES) propertyList))* #createNamespace + | ALTER namespace identifierReference + SET (DBPROPERTIES | PROPERTIES) propertyList #setNamespaceProperties + | ALTER namespace identifierReference + SET locationSpec #setNamespaceLocation + | DROP namespace (IF EXISTS)? identifierReference + (RESTRICT | CASCADE)? #dropNamespace + | SHOW namespaces ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=stringLit)? #showNamespaces + | createTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier + (tableProvider | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=propertyList))* #createTableLike + | replaceTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ANALYZE TABLES ((FROM | IN) identifierReference)? COMPUTE STATISTICS + (identifier)? #analyzeTables + | ALTER TABLE identifierReference + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE identifierReference + ADD (COLUMN | COLUMNS) + LEFT_PAREN columns=qualifiedColTypeWithPositionList RIGHT_PAREN #addTableColumns + | ALTER TABLE table=identifierReference + RENAME COLUMN + from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn + | ALTER TABLE identifierReference + DROP (COLUMN | COLUMNS) (IF EXISTS)? + LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN #dropTableColumns + | ALTER TABLE identifierReference + DROP (COLUMN | COLUMNS) (IF EXISTS)? + columns=multipartIdentifierList #dropTableColumns + | ALTER (TABLE | VIEW) from=identifierReference + RENAME TO to=multipartIdentifier #renameTable + | ALTER (TABLE | VIEW) identifierReference + SET TBLPROPERTIES propertyList #setTableProperties + | ALTER (TABLE | VIEW) identifierReference + UNSET TBLPROPERTIES (IF EXISTS)? propertyList #unsetTableProperties + | ALTER TABLE table=identifierReference + (ALTER | CHANGE) COLUMN? column=multipartIdentifier + alterColumnAction? #alterTableAlterColumn + | ALTER TABLE table=identifierReference partitionSpec? + CHANGE COLUMN? + colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=identifierReference partitionSpec? + REPLACE COLUMNS + LEFT_PAREN columns=qualifiedColTypeWithPositionList + RIGHT_PAREN #hiveReplaceColumns + | ALTER TABLE identifierReference (partitionSpec)? + SET SERDE stringLit (WITH SERDEPROPERTIES propertyList)? #setTableSerDe + | ALTER TABLE identifierReference (partitionSpec)? + SET SERDEPROPERTIES propertyList #setTableSerDe + | ALTER (TABLE | VIEW) identifierReference ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE identifierReference + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER (TABLE | VIEW) identifierReference + DROP (IF EXISTS)? partitionSpec (COMMA partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE identifierReference + (partitionSpec)? SET locationSpec #setTableLocation + | ALTER TABLE identifierReference RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable + | DROP VIEW (IF EXISTS)? identifierReference #dropView + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? identifierReference + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES propertyList))* + AS query #createView + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW + tableIdentifier (LEFT_PAREN colTypeList RIGHT_PAREN)? tableProvider + (OPTIONS propertyList)? #createTempViewUsing + | ALTER VIEW identifierReference AS? query #alterViewQuery + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + identifierReference AS className=stringLit + (USING resource (COMMA resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction + | DECLARE (OR REPLACE)? VARIABLE? + identifierReference dataType? variableDefaultExpression? #createVariable + | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain + | SHOW TABLES ((FROM | IN) identifierReference)? + (LIKE? pattern=stringLit)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) ns=identifierReference)? + LIKE pattern=stringLit partitionSpec? #showTableExtended + | SHOW TBLPROPERTIES table=identifierReference + (LEFT_PAREN key=propertyKey RIGHT_PAREN)? #showTblProperties + | SHOW COLUMNS (FROM | IN) table=identifierReference + ((FROM | IN) ns=multipartIdentifier)? #showColumns + | SHOW VIEWS ((FROM | IN) identifierReference)? + (LIKE? pattern=stringLit)? #showViews + | SHOW PARTITIONS identifierReference partitionSpec? #showPartitions + | SHOW identifier? FUNCTIONS ((FROM | IN) ns=identifierReference)? + (LIKE? (legacy=multipartIdentifier | pattern=stringLit))? #showFunctions + | SHOW CREATE TABLE identifierReference (AS SERDE)? #showCreateTable + | SHOW CURRENT namespace #showCurrentNamespace + | SHOW CATALOGS (LIKE? pattern=stringLit)? #showCatalogs + | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) namespace EXTENDED? + identifierReference #describeNamespace + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? + identifierReference partitionSpec? describeColName? #describeRelation + | (DESC | DESCRIBE) QUERY? query #describeQuery + | COMMENT ON namespace identifierReference IS + comment #commentNamespace + | COMMENT ON TABLE identifierReference IS comment #commentTable + | REFRESH TABLE identifierReference #refreshTable + | REFRESH FUNCTION identifierReference #refreshFunction + | REFRESH (stringLit | .*?) #refreshResource + | CACHE LAZY? TABLE identifierReference + (OPTIONS options=propertyList)? (AS? query)? #cacheTable + | UNCACHE TABLE (IF EXISTS)? identifierReference #uncacheTable + | CLEAR CACHE #clearCache + | LOAD DATA LOCAL? INPATH path=stringLit OVERWRITE? INTO TABLE + identifierReference partitionSpec? #loadData + | TRUNCATE TABLE identifierReference partitionSpec? #truncateTable + | (MSCK)? REPAIR TABLE identifierReference + (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable + | op=(ADD | LIST) identifier .*? #manageResource + | SET ROLE .*? #failNativeCommand + | SET TIME ZONE interval #setTimeZone + | SET TIME ZONE timezone #setTimeZone + | SET TIME ZONE .*? #setTimeZone + | SET (VARIABLE | VAR) assignmentList #setVariable + | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + LEFT_PAREN query RIGHT_PAREN #setVariable + | SET configKey EQ configValue #setQuotedConfiguration + | SET configKey (EQ .*?)? #setConfiguration + | SET .*? EQ configValue #setQuotedConfiguration + | SET .*? #setConfiguration + | RESET configKey #resetQuotedConfiguration + | RESET .*? #resetConfiguration + | CREATE INDEX (IF NOT EXISTS)? identifier ON TABLE? + identifierReference (USING indexType=identifier)? + LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN + (OPTIONS options=propertyList)? #createIndex + | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +timezone + : stringLit + | LOCAL + ; + +configKey + : quotedIdentifier + ; + +configValue + : backQuotedIdentifier + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS + | kw1=START kw2=TRANSACTION + | kw1=COMMIT + | kw1=ROLLBACK + | kw1=DFS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? identifierReference + ; + +replaceTableHeader + : (CREATE OR)? REPLACE TABLE identifierReference + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION stringLit + ; + +commentSpec + : COMMENT stringLit + ; + +query + : ctes? queryTerm queryOrganization + ; + +insertInto + : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF NOT EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable + | INSERT INTO TABLE? identifierReference partitionSpec? (IF NOT EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable + | INSERT INTO TABLE? identifierReference REPLACE whereClause #insertIntoReplaceWhere + | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat? #insertOverwriteHiveDir + | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? #insertOverwriteDir + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION LEFT_PAREN partitionVal (COMMA partitionVal)* RIGHT_PAREN + ; + +partitionVal + : identifier (EQ constant)? + | identifier EQ DEFAULT + ; + +namespace + : NAMESPACE + | DATABASE + | SCHEMA + ; + +namespaces + : NAMESPACES + | DATABASES + | SCHEMAS + ; + +describeFuncName + : identifierReference + | stringLit + | comparisonOperator + | arithmeticOperator + | predicateOperator + ; + +describeColName + : nameParts+=identifier (DOT nameParts+=identifier)* + ; + +ctes + : WITH namedQuery (COMMA namedQuery)* + ; + +namedQuery + : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? LEFT_PAREN query RIGHT_PAREN + ; + +tableProvider + : USING multipartIdentifier + ; + +createTableClauses + :((OPTIONS options=expressionPropertyList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | + bucketSpec | + rowFormat | + createFileFormat | + locationSpec | + commentSpec | + (TBLPROPERTIES tableProps=propertyList))* + ; + +propertyList + : LEFT_PAREN property (COMMA property)* RIGHT_PAREN + ; + +property + : key=propertyKey (EQ? value=propertyValue)? + ; + +propertyKey + : identifier (DOT identifier)* + | stringLit + ; + +propertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | stringLit + ; + +expressionPropertyList + : LEFT_PAREN expressionProperty (COMMA expressionProperty)* RIGHT_PAREN + ; + +expressionProperty + : key=propertyKey (EQ? value=expression)? + ; + +constantList + : LEFT_PAREN constant (COMMA constant)* RIGHT_PAREN + ; + +nestedConstantList + : LEFT_PAREN constantList (COMMA constantList)* RIGHT_PAREN + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=stringLit OUTPUTFORMAT outFmt=stringLit #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : stringLit (WITH SERDEPROPERTIES propertyList)? + ; + +resource + : identifier stringLit + ; + +dmlStatementNoWith + : insertInto query #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + | DELETE FROM identifierReference tableAlias whereClause? #deleteFromTable + | UPDATE identifierReference tableAlias setClause whereClause? #updateTable + | MERGE INTO target=identifierReference targetAlias=tableAlias + USING (source=identifierReference | + LEFT_PAREN sourceQuery=query RIGHT_PAREN) sourceAlias=tableAlias + ON mergeCondition=booleanExpression + matchedClause* + notMatchedClause* + notMatchedBySourceClause* #mergeIntoTable + ; + +identifierReference + : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | multipartIdentifier + ; + +queryOrganization + : (ORDER BY order+=sortItem (COMMA order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (COMMA clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (COMMA distributeBy+=expression)*)? + (SORT BY sort+=sortItem (COMMA sort+=sortItem)*)? + windowClause? + (LIMIT (ALL | limit=expression))? + (OFFSET offset=expression)? + ; + +multiInsertQueryBody + : insertInto fromStatementBody + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enabled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | fromStatement #fromStmt + | TABLE identifierReference #table + | inlineTable #inlineTableDefault1 + | LEFT_PAREN query RIGHT_PAREN #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? + ; + +fromStatement + : fromClause fromStatementBody+ + ; + +fromStatementBody + : transformClause + whereClause? + queryOrganization + | selectClause + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? + queryOrganization + ; + +querySpecification + : transformClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #transformQuerySpecification + | selectClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #regularQuerySpecification + ; + +transformClause + : (SELECT kind=TRANSFORM LEFT_PAREN setQuantifier? expressionSeq RIGHT_PAREN + | kind=MAP setQuantifier? expressionSeq + | kind=REDUCE setQuantifier? expressionSeq) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=stringLit)? + USING script=stringLit + (AS (identifierSeq | colTypeList | (LEFT_PAREN (identifierSeq | colTypeList) RIGHT_PAREN)))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=stringLit)? + ; + +selectClause + : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq + ; + +setClause + : SET assignmentList + ; + +matchedClause + : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction + ; +notMatchedClause + : WHEN NOT MATCHED (BY TARGET)? (AND notMatchedCond=booleanExpression)? THEN notMatchedAction + ; + +notMatchedBySourceClause + : WHEN NOT MATCHED BY SOURCE (AND notMatchedBySourceCond=booleanExpression)? THEN notMatchedBySourceAction + ; + +matchedAction + : DELETE + | UPDATE SET ASTERISK + | UPDATE SET assignmentList + ; + +notMatchedAction + : INSERT ASTERISK + | INSERT LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN + VALUES LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN + ; + +notMatchedBySourceAction + : DELETE + | UPDATE SET assignmentList + ; + +assignmentList + : assignment (COMMA assignment)* + ; + +assignment + : key=multipartIdentifier EQ value=expression + ; + +whereClause + : WHERE booleanExpression + ; + +havingClause + : HAVING booleanExpression + ; + +hint + : HENT_START hintStatements+=hintStatement (COMMA? hintStatements+=hintStatement)* HENT_END + ; + +hintStatement + : hintName=identifier + | hintName=identifier LEFT_PAREN parameters+=primaryExpression (COMMA parameters+=primaryExpression)* RIGHT_PAREN + ; + +fromClause + : FROM relation (COMMA relation)* lateralView* pivotClause? unpivotClause? + ; + +temporalClause + : FOR? (SYSTEM_VERSION | VERSION) AS OF version + | FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + ; + +aggregationClause + : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause + (COMMA groupingExpressionsWithGroupingAnalytics+=groupByClause)* + | GROUP BY groupingExpressions+=expression (COMMA groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN)? + ; + +groupByClause + : groupingAnalytics + | expression + ; + +groupingAnalytics + : (ROLLUP | CUBE) LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN + | GROUPING SETS LEFT_PAREN groupingElement (COMMA groupingElement)* RIGHT_PAREN + ; + +groupingElement + : groupingAnalytics + | groupingSet + ; + +groupingSet + : LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN + | expression + ; + +pivotClause + : PIVOT LEFT_PAREN aggregates=namedExpressionSeq FOR pivotColumn IN LEFT_PAREN pivotValues+=pivotValue (COMMA pivotValues+=pivotValue)* RIGHT_PAREN RIGHT_PAREN + ; + +pivotColumn + : identifiers+=identifier + | LEFT_PAREN identifiers+=identifier (COMMA identifiers+=identifier)* RIGHT_PAREN + ; + +pivotValue + : expression (AS? identifier)? + ; + +unpivotClause + : UNPIVOT nullOperator=unpivotNullClause? LEFT_PAREN + operator=unpivotOperator + RIGHT_PAREN (AS? identifier)? + ; + +unpivotNullClause + : (INCLUDE | EXCLUDE) NULLS + ; + +unpivotOperator + : (unpivotSingleValueColumnClause | unpivotMultiValueColumnClause) + ; + +unpivotSingleValueColumnClause + : unpivotValueColumn FOR unpivotNameColumn IN LEFT_PAREN unpivotColumns+=unpivotColumnAndAlias (COMMA unpivotColumns+=unpivotColumnAndAlias)* RIGHT_PAREN + ; + +unpivotMultiValueColumnClause + : LEFT_PAREN unpivotValueColumns+=unpivotValueColumn (COMMA unpivotValueColumns+=unpivotValueColumn)* RIGHT_PAREN + FOR unpivotNameColumn + IN LEFT_PAREN unpivotColumnSets+=unpivotColumnSet (COMMA unpivotColumnSets+=unpivotColumnSet)* RIGHT_PAREN + ; + +unpivotColumnSet + : LEFT_PAREN unpivotColumns+=unpivotColumn (COMMA unpivotColumns+=unpivotColumn)* RIGHT_PAREN unpivotAlias? + ; + +unpivotValueColumn + : identifier + ; + +unpivotNameColumn + : identifier + ; + +unpivotColumnAndAlias + : unpivotColumn unpivotAlias? + ; + +unpivotColumn + : multipartIdentifier + ; + +unpivotAlias + : AS? identifier + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN tblName=identifier (AS? colName+=identifier (COMMA colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : LATERAL? relationPrimary relationExtension* + ; + +relationExtension + : joinRelation + | pivotClause + | unpivotClause + ; + +joinRelation + : (joinType) JOIN LATERAL? right=relationPrimary joinCriteria? + | NATURAL joinType JOIN LATERAL? right=relationPrimary + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | LEFT? SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING identifierList + ; + +sample + : TABLESAMPLE LEFT_PAREN sampleMethod? RIGHT_PAREN (REPEATABLE LEFT_PAREN seed=INTEGER_VALUE RIGHT_PAREN)? + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName LEFT_PAREN RIGHT_PAREN))? #sampleByBucket + | bytes=expression #sampleByBytes + ; + +identifierList + : LEFT_PAREN identifierSeq RIGHT_PAREN + ; + +identifierSeq + : ident+=errorCapturingIdentifier (COMMA ident+=errorCapturingIdentifier)* + ; + +orderedIdentifierList + : LEFT_PAREN orderedIdentifier (COMMA orderedIdentifier)* RIGHT_PAREN + ; + +orderedIdentifier + : ident=errorCapturingIdentifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : LEFT_PAREN identifierComment (COMMA identifierComment)* RIGHT_PAREN + ; + +identifierComment + : identifier commentSpec? + ; + +relationPrimary + : identifierReference temporalClause? + sample? tableAlias #tableName + | LEFT_PAREN query RIGHT_PAREN sample? tableAlias #aliasedQuery + | LEFT_PAREN relation RIGHT_PAREN sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction + ; + +inlineTable + : VALUES expression (COMMA expression)* tableAlias + ; + +functionTableSubqueryArgument + : TABLE identifierReference tableArgumentPartitioning? + | TABLE LEFT_PAREN identifierReference RIGHT_PAREN tableArgumentPartitioning? + | TABLE LEFT_PAREN query RIGHT_PAREN tableArgumentPartitioning? + ; + +tableArgumentPartitioning + : ((WITH SINGLE PARTITION) + | ((PARTITION | DISTRIBUTE) BY + (((LEFT_PAREN partition+=expression (COMMA partition+=expression)* RIGHT_PAREN)) + | partition+=expression))) + ((ORDER | SORT) BY + (((LEFT_PAREN sortItem (COMMA sortItem)* RIGHT_PAREN) + | sortItem)))? + ; + +functionTableNamedArgumentExpression + : key=identifier FAT_ARROW table=functionTableSubqueryArgument + ; + +functionTableReferenceArgument + : functionTableSubqueryArgument + | functionTableNamedArgumentExpression + ; + +functionTableArgument + : functionTableReferenceArgument + | functionArgument + ; + +functionTable + : funcName=functionName LEFT_PAREN + (functionTableArgument (COMMA functionTableArgument)*)? + RIGHT_PAREN tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=stringLit (WITH SERDEPROPERTIES props=propertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=stringLit (ESCAPED BY escapedBy=stringLit)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=stringLit)? + (MAP KEYS TERMINATED BY keysTerminatedBy=stringLit)? + (LINES TERMINATED BY linesSeparatedBy=stringLit)? + (NULL DEFINED AS nullDefinedAs=stringLit)? #rowFormatDelimited + ; + +multipartIdentifierList + : multipartIdentifier (COMMA multipartIdentifier)* + ; + +multipartIdentifier + : parts+=errorCapturingIdentifier (DOT parts+=errorCapturingIdentifier)* + ; + +multipartIdentifierPropertyList + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* + ; + +multipartIdentifierProperty + : multipartIdentifier (OPTIONS options=propertyList)? + ; + +tableIdentifier + : (db=errorCapturingIdentifier DOT)? table=errorCapturingIdentifier + ; + +functionIdentifier + : (db=errorCapturingIdentifier DOT)? function=errorCapturingIdentifier + ; + +namedExpression + : expression (AS? (name=errorCapturingIdentifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (COMMA namedExpression)* + ; + +partitionFieldList + : LEFT_PAREN fields+=partitionField (COMMA fields+=partitionField)* RIGHT_PAREN + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + LEFT_PAREN argument+=transformArgument (COMMA argument+=transformArgument)* RIGHT_PAREN #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + +expression + : booleanExpression + ; + +namedArgumentExpression + : key=identifier FAT_ARROW value=expression + ; + +functionArgument + : expression + | namedArgumentExpression + ; + +expressionSeq + : expression (COMMA expression)* + ; + +booleanExpression + : NOT booleanExpression #logicalNot + | EXISTS LEFT_PAREN query RIGHT_PAREN #exists + | valueExpression predicate? #predicated + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + ; + +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN + | NOT? kind=IN LEFT_PAREN query RIGHT_PAREN + | NOT? kind=RLIKE pattern=valueExpression + | NOT? kind=(LIKE | ILIKE) quantifier=(ANY | SOME | ALL) (LEFT_PAREN RIGHT_PAREN | LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN) + | NOT? kind=(LIKE | ILIKE) pattern=valueExpression (ESCAPE escapeChar=stringLit)? + | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) + | IS NOT? kind=DISTINCT FROM right=valueExpression + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +datetimeUnit + : YEAR | QUARTER | MONTH + | WEEK | DAY | DAYOFYEAR + | HOUR | MINUTE | SECOND | MILLISECOND | MICROSECOND + ; + +primaryExpression + : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER | USER | SESSION_USER) #currentLike + | name=(TIMESTAMPADD | DATEADD | DATE_ADD) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA unitsAmount=valueExpression COMMA timestamp=valueExpression RIGHT_PAREN #timestampadd + | name=(TIMESTAMPDIFF | DATEDIFF | DATE_DIFF | TIMEDIFF) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA startTimestamp=valueExpression COMMA endTimestamp=valueExpression RIGHT_PAREN #timestampdiff + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast + | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct + | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first + | ANY_VALUE LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #any_value + | LAST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #last + | POSITION LEFT_PAREN substr=valueExpression IN str=valueExpression RIGHT_PAREN #position + | constant #constantDefault + | ASTERISK #star + | qualifiedName DOT ASTERISK #star + | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor + | LEFT_PAREN query RIGHT_PAREN #subqueryExpression + | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause + | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument + (COMMA argument+=functionArgument)*)? RIGHT_PAREN + (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? + (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall + | identifier ARROW expression #lambda + | LEFT_PAREN identifier (COMMA identifier)+ RIGHT_PAREN ARROW expression #lambda + | value=primaryExpression LEFT_BRACKET index=valueExpression RIGHT_BRACKET #subscript + | identifier #columnReference + | base=primaryExpression DOT fieldName=identifier #dereference + | LEFT_PAREN expression RIGHT_PAREN #parenthesizedExpression + | EXTRACT LEFT_PAREN field=identifier FROM source=valueExpression RIGHT_PAREN #extract + | (SUBSTR | SUBSTRING) LEFT_PAREN str=valueExpression (FROM | COMMA) pos=valueExpression + ((FOR | COMMA) len=valueExpression)? RIGHT_PAREN #substring + | TRIM LEFT_PAREN trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression RIGHT_PAREN #trim + | OVERLAY LEFT_PAREN input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN #overlay + | name=(PERCENTILE_CONT | PERCENTILE_DISC) LEFT_PAREN percentage=valueExpression RIGHT_PAREN + WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN + (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)? #percentile + ; + +literalType + : DATE + | TIMESTAMP | TIMESTAMP_LTZ | TIMESTAMP_NTZ + | INTERVAL + | BINARY_HEX + | unsupportedType=identifier + ; + +constant + : NULL #nullLiteral + | QUESTION #posParameterLiteral + | COLON identifier #namedParameterLiteral + | interval #intervalLiteral + | literalType stringLit #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | stringLit+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +arithmeticOperator + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT + ; + +predicateOperator + : OR | AND | IN | NOT + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval) + ; + +errorCapturingMultiUnitsInterval + : body=multiUnitsInterval unitToUnitInterval? + ; + +multiUnitsInterval + : (intervalValue unit+=unitInMultiUnits)+ + ; + +errorCapturingUnitToUnitInterval + : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)? + ; + +unitToUnitInterval + : value=intervalValue from=unitInUnitToUnit TO to=unitInUnitToUnit + ; + +intervalValue + : (PLUS | MINUS)? + (INTEGER_VALUE | DECIMAL_VALUE | stringLit) + ; + +unitInMultiUnits + : NANOSECOND | NANOSECONDS | MICROSECOND | MICROSECONDS | MILLISECOND | MILLISECONDS + | SECOND | SECONDS | MINUTE | MINUTES | HOUR | HOURS | DAY | DAYS | WEEK | WEEKS + | MONTH | MONTHS | YEAR | YEARS + ; + +unitInUnitToUnit + : SECOND | MINUTE | HOUR | DAY | MONTH | YEAR + ; + +colPosition + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier + ; + +type + : BOOLEAN + | TINYINT | BYTE + | SMALLINT | SHORT + | INT | INTEGER + | BIGINT | LONG + | FLOAT | REAL + | DOUBLE + | DATE + | TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ + | STRING + | CHARACTER | CHAR + | VARCHAR + | BINARY + | DECIMAL | DEC | NUMERIC + | VOID + | INTERVAL + | ARRAY | STRUCT | MAP + | unsupportedType=identifier + ; + +dataType + : complex=ARRAY LT dataType GT #complexDataType + | complex=MAP LT dataType COMMA dataType GT #complexDataType + | complex=STRUCT (LT complexColTypeList? GT | NEQ) #complexDataType + | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType + | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) + (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType + | type (LEFT_PAREN INTEGER_VALUE + (COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType + ; + +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (COMMA qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=multipartIdentifier dataType colDefinitionDescriptorWithPosition* + ; + +colDefinitionDescriptorWithPosition + : NOT NULL + | defaultExpression + | commentSpec + | colPosition + ; + +defaultExpression + : DEFAULT expression + ; + +variableDefaultExpression + : (DEFAULT | EQ) expression + ; + +colTypeList + : colType (COMMA colType)* + ; + +colType + : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? + ; + +createOrReplaceTableColTypeList + : createOrReplaceTableColType (COMMA createOrReplaceTableColType)* + ; + +createOrReplaceTableColType + : colName=errorCapturingIdentifier dataType colDefinitionOption* + ; + +colDefinitionOption + : NOT NULL + | defaultExpression + | generationExpression + | commentSpec + ; + +generationExpression + : GENERATED ALWAYS AS LEFT_PAREN expression RIGHT_PAREN + ; + +complexColTypeList + : complexColType (COMMA complexColType)* + ; + +complexColType + : identifier COLON? dataType (NOT NULL)? commentSpec? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windowClause + : WINDOW namedWindow (COMMA namedWindow)* + ; + +namedWindow + : name=errorCapturingIdentifier AS windowSpec + ; + +windowSpec + : name=errorCapturingIdentifier #windowRef + | LEFT_PAREN name=errorCapturingIdentifier RIGHT_PAREN #windowRef + | LEFT_PAREN + ( CLUSTER BY partition+=expression (COMMA partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (COMMA partition+=expression)*)? + ((ORDER | SORT) BY sortItem (COMMA sortItem)*)?) + windowFrame? + RIGHT_PAREN #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + +qualifiedNameList + : qualifiedName (COMMA qualifiedName)* + ; + +functionName + : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | qualifiedName + | FILTER + | LEFT + | RIGHT + ; + +qualifiedName + : identifier (DOT identifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : strictIdentifier + | {!SQL_standard_keyword_behavior}? strictNonReserved + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier + | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + | {double_quoted_identifiers}? DOUBLEQUOTED_STRING + ; + +backQuotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral + | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral + | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + | SET defaultExpression + | dropDefault=DROP DEFAULT + ; + +stringLit + : STRING_LITERAL + | {!double_quoted_identifiers}? DOUBLEQUOTED_STRING + ; + +comment + : stringLit + | NULL + ; + +version + : INTEGER_VALUE + | stringLit + ; + +// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. +// - Reserved keywords: +// Keywords that are reserved and can't be used as identifiers for table, view, column, +// function, alias, etc. +// - Non-reserved keywords: +// Keywords that have a special meaning only in particular contexts and can be used as +// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN +// can be used as identifiers in other places. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords. +ansiNonReserved +//--ANSI-NON-RESERVED-START + : ADD + | AFTER + | ALTER + | ALWAYS + | ANALYZE + | ANTI + | ANY_VALUE + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BIGINT + | BINARY + | BINARY_HEX + | BOOLEAN + | BUCKET + | BUCKETS + | BY + | BYTE + | CACHE + | CASCADE + | CATALOG + | CATALOGS + | CHANGE + | CHAR + | CHARACTER + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DATE + | DATEADD + | DATE_ADD + | DATEDIFF + | DATE_DIFF + | DAY + | DAYS + | DAYOFYEAR + | DBPROPERTIES + | DEC + | DECIMAL + | DECLARE + | DEFAULT + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DOUBLE + | DROP + | ESCAPED + | EXCHANGE + | EXCLUDE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FLOAT + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GENERATED + | GLOBAL + | GROUPING + | HOUR + | HOURS + | IDENTIFIER_KW + | IF + | IGNORE + | IMPORT + | INCLUDE + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INT + | INTEGER + | INTERVAL + | ITEMS + | KEYS + | LAST + | LAZY + | LIKE + | ILIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | LONG + | MACRO + | MAP + | MATCHED + | MERGE + | MICROSECOND + | MICROSECONDS + | MILLISECOND + | MILLISECONDS + | MINUTE + | MINUTES + | MONTH + | MONTHS + | MSCK + | NAME + | NAMESPACE + | NAMESPACES + | NANOSECOND + | NANOSECONDS + | NO + | NULLS + | NUMERIC + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRINCIPALS + | PROPERTIES + | PURGE + | QUARTER + | QUERY + | RANGE + | REAL + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPEATABLE + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SCHEMAS + | SECOND + | SECONDS + | SEMI + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETMINUS + | SETS + | SHORT + | SHOW + | SINGLE + | SKEWED + | SMALLINT + | SORT + | SORTED + | SOURCE + | START + | STATISTICS + | STORED + | STRATIFY + | STRING + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | SYSTEM_TIME + | SYSTEM_VERSION + | TABLES + | TABLESAMPLE + | TARGET + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TIMEDIFF + | TIMESTAMP + | TIMESTAMP_LTZ + | TIMESTAMP_NTZ + | TIMESTAMPADD + | TIMESTAMPDIFF + | TINYINT + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNPIVOT + | UNSET + | UPDATE + | USE + | VALUES + | VARCHAR + | VAR + | VARIABLE + | VERSION + | VIEW + | VIEWS + | VOID + | WEEK + | WEEKS + | WINDOW + | YEAR + | YEARS + | ZONE +//--ANSI-NON-RESERVED-END + ; + +// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL. +// - Non-reserved keywords: +// Same definition as the one when `SQL_standard_keyword_behavior=true`. +// - Strict-non-reserved keywords: +// A strict version of non-reserved keywords, which can not be used as table alias. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The strict-non-reserved keywords are listed in `strictNonReserved`. +// The non-reserved keywords are listed in `nonReserved`. +// These 2 together contain all the keywords. +strictNonReserved + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LATERAL + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION + | USING + ; + +nonReserved +//--DEFAULT-NON-RESERVED-START + : ADD + | AFTER + | ALL + | ALTER + | ALWAYS + | ANALYZE + | AND + | ANY + | ANY_VALUE + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BIGINT + | BINARY + | BINARY_HEX + | BOOLEAN + | BOTH + | BUCKET + | BUCKETS + | BY + | BYTE + | CACHE + | CASCADE + | CASE + | CAST + | CATALOG + | CATALOGS + | CHANGE + | CHAR + | CHARACTER + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DATE + | DATEADD + | DATE_ADD + | DATEDIFF + | DATE_DIFF + | DAY + | DAYS + | DAYOFYEAR + | DBPROPERTIES + | DEC + | DECIMAL + | DECLARE + | DEFAULT + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DOUBLE + | DROP + | ELSE + | END + | ESCAPE + | ESCAPED + | EXCHANGE + | EXCLUDE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FILTER + | FIELDS + | FILEFORMAT + | FIRST + | FLOAT + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GENERATED + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | HOUR + | HOURS + | IDENTIFIER_KW + | IF + | IGNORE + | IMPORT + | IN + | INCLUDE + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INT + | INTEGER + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LAZY + | LEADING + | LIKE + | LONG + | ILIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | LONG + | MACRO + | MAP + | MATCHED + | MERGE + | MICROSECOND + | MICROSECONDS + | MILLISECOND + | MILLISECONDS + | MINUTE + | MINUTES + | MONTH + | MONTHS + | MSCK + | NAME + | NAMESPACE + | NAMESPACES + | NANOSECOND + | NANOSECONDS + | NO + | NOT + | NULL + | NULLS + | NUMERIC + | OF + | OFFSET + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTILE_CONT + | PERCENTILE_DISC + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PROPERTIES + | PURGE + | QUARTER + | QUERY + | RANGE + | REAL + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPEATABLE + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SCHEMAS + | SECOND + | SECONDS + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHORT + | SHOW + | SINGLE + | SKEWED + | SMALLINT + | SOME + | SORT + | SORTED + | SOURCE + | START + | STATISTICS + | STORED + | STRATIFY + | STRING + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | SYSTEM_TIME + | SYSTEM_VERSION + | TABLE + | TABLES + | TABLESAMPLE + | TARGET + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | THEN + | TIME + | TIMEDIFF + | TIMESTAMP + | TIMESTAMP_LTZ + | TIMESTAMP_NTZ + | TIMESTAMPADD + | TIMESTAMPDIFF + | TINYINT + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNKNOWN + | UNLOCK + | UNPIVOT + | UNSET + | UPDATE + | USE + | USER + | VALUES + | VARCHAR + | VAR + | VARIABLE + | VERSION + | VIEW + | VIEWS + | VOID + | WEEK + | WEEKS + | WHEN + | WHERE + | WINDOW + | WITH + | WITHIN + | YEAR + | YEARS + | ZONE +//--DEFAULT-NON-RESERVED-END + ; diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index efc23e08b5..efb07914e7 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -22,6 +22,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -62,9 +63,11 @@ public CreateAsyncQueryResponse createAsyncQuery( sparkExecutionEngineConfigString)); String jobId = sparkQueryDispatcher.dispatch( - sparkExecutionEngineConfig.getApplicationId(), - createAsyncQueryRequest.getQuery(), - sparkExecutionEngineConfig.getExecutionRoleARN()); + new DispatchQueryRequest( + sparkExecutionEngineConfig.getApplicationId(), + createAsyncQueryRequest.getQuery(), + createAsyncQueryRequest.getLang(), + sparkExecutionEngineConfig.getExecutionRoleARN())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); return new CreateAsyncQueryResponse(jobId); diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index 2377b2f5da..796dae1a9d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -33,24 +33,21 @@ public EmrServerlessClientImpl(AWSEMRServerless emrServerless) { } @Override - public String startJobRun( - String query, - String jobName, - String applicationId, - String executionRoleArn, - String sparkSubmitParams) { + public String startJobRun(StartJobRequest startJobRequest) { StartJobRunRequest request = new StartJobRunRequest() - .withName(jobName) - .withApplicationId(applicationId) - .withExecutionRoleArn(executionRoleArn) + .withName(startJobRequest.getJobName()) + .withApplicationId(startJobRequest.getApplicationId()) + .withExecutionRoleArn(startJobRequest.getExecutionRoleArn()) + .withTags(startJobRequest.getTags()) .withJobDriver( new JobDriver() .withSparkSubmit( new SparkSubmit() .withEntryPoint(SPARK_SQL_APPLICATION_JAR) - .withEntryPointArguments(query, SPARK_RESPONSE_BUFFER_INDEX_NAME) - .withSparkSubmitParameters(sparkSubmitParams))); + .withEntryPointArguments( + startJobRequest.getQuery(), SPARK_RESPONSE_BUFFER_INDEX_NAME) + .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams()))); StartJobRunResult startJobRunResult = AccessController.doPrivileged( (PrivilegedAction) () -> emrServerless.startJobRun(request)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java index c6b3059c77..762c94791a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java @@ -12,12 +12,7 @@ public interface SparkJobClient { - String startJobRun( - String query, - String jobName, - String applicationId, - String executionRoleArn, - String sparkSubmitParams); + String startJobRun(StartJobRequest startJobRequest); GetJobRunResult getJobRunResult(String applicationId, String jobId); diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java new file mode 100644 index 0000000000..c0815657d1 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import java.util.Map; +import lombok.Data; + +@Data +public class StartJobRequest { + private final String query; + private final String jobName; + private final String applicationId; + private final String executionRoleArn; + private final String sparkSubmitParams; + private final Map tags; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 442838331f..8c25fcc160 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -20,13 +20,23 @@ import com.amazonaws.services.emrserverless.model.JobRunState; import java.net.URI; import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; import lombok.AllArgsConstructor; import org.json.JSONObject; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.spark.asyncquery.model.S3GlueSparkSubmitParameters; import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.utils.SQLQueryUtils; /** This class takes care of understanding query and dispatching job query to emr serverless. */ @AllArgsConstructor @@ -36,22 +46,12 @@ public class SparkQueryDispatcher { private DataSourceService dataSourceService; + private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; + private JobExecutionResponseReader jobExecutionResponseReader; - public String dispatch(String applicationId, String query, String executionRoleARN) { - String datasourceName = getDataSourceName(); - try { - return sparkJobClient.startJobRun( - query, - "flint-opensearch-query", - applicationId, - executionRoleARN, - constructSparkParameters(datasourceName)); - } catch (URISyntaxException e) { - throw new IllegalArgumentException( - String.format( - "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); - } + public String dispatch(DispatchQueryRequest dispatchQueryRequest) { + return sparkJobClient.startJobRun(getStartJobRequest(dispatchQueryRequest)); } // TODO : Fetch from Result Index and then make call to EMR Serverless. @@ -70,19 +70,29 @@ public String cancelJob(String applicationId, String jobId) { return cancelJobRunResult.getJobRunId(); } - // TODO: Analyze given query - // Extract datasourceName - // Apply Authorizaiton. - private String getDataSourceName() { - return "my_glue"; + private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryRequest) { + if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { + if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) + return getStartJobRequestForIndexRequest(dispatchQueryRequest); + else { + return getStartJobRequestForNonIndexQueries(dispatchQueryRequest); + } + } + throw new UnsupportedOperationException( + String.format("UnSupported Lang type:: %s", dispatchQueryRequest.getLangType())); } - // TODO: Analyze given query and get the role arn based on datasource type. private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) { - return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + if (DataSourceType.S3GLUE.equals(dataSourceMetadata.getConnector())) { + return dataSourceMetadata.getProperties().get("glue.auth.role_arn"); + } + throw new UnsupportedOperationException( + String.format( + "UnSupported datasource type for async queries:: %s", + dataSourceMetadata.getConnector())); } - private String constructSparkParameters(String datasourceName) throws URISyntaxException { + private String constructSparkParameters(String datasourceName) { DataSourceMetadata dataSourceMetadata = dataSourceService.getRawDataSourceMetadata(datasourceName); S3GlueSparkSubmitParameters s3GlueSparkSubmitParameters = new S3GlueSparkSubmitParameters(); @@ -93,7 +103,14 @@ private String constructSparkParameters(String datasourceName) throws URISyntaxE s3GlueSparkSubmitParameters.addParameter( HIVE_METASTORE_GLUE_ARN_KEY, getDataSourceRoleARN(dataSourceMetadata)); String opensearchuri = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.uri"); - URI uri = new URI(opensearchuri); + URI uri; + try { + uri = new URI(opensearchuri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); + } String auth = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.auth"); String region = dataSourceMetadata.getProperties().get("glue.indexstore.opensearch.region"); s3GlueSparkSubmitParameters.addParameter(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); @@ -106,4 +123,69 @@ private String constructSparkParameters(String datasourceName) throws URISyntaxE "spark.sql.catalog." + datasourceName, FLINT_DELEGATE_CATALOG); return s3GlueSparkSubmitParameters.toString(); } + + private StartJobRequest getStartJobRequestForNonIndexQueries( + DispatchQueryRequest dispatchQueryRequest) { + StartJobRequest startJobRequest; + FullyQualifiedTableName fullyQualifiedTableName = + SQLQueryUtils.extractFullyQualifiedTableName(dispatchQueryRequest.getQuery()); + if (fullyQualifiedTableName == null || fullyQualifiedTableName.getDatasourceName() == null) { + throw new UnsupportedOperationException("Queries without a datasource are not supported"); + } + Map tags = new HashMap<>(); + dataSourceUserAuthorizationHelper.authorizeDataSource( + this.dataSourceService.getRawDataSourceMetadata( + fullyQualifiedTableName.getDatasourceName())); + String jobName = + fullyQualifiedTableName.getDatasourceName() + + "-" + + fullyQualifiedTableName.getSchemaName() + + "-" + + fullyQualifiedTableName.getTableName(); + tags.put("datasource", fullyQualifiedTableName.getDatasourceName()); + tags.put("table", fullyQualifiedTableName.getTableName()); + startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), + tags); + return startJobRequest; + } + + private StartJobRequest getStartJobRequestForIndexRequest( + DispatchQueryRequest dispatchQueryRequest) { + StartJobRequest startJobRequest; + IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + if (fullyQualifiedTableName.getDatasourceName() == null) { + throw new UnsupportedOperationException("Queries without a datasource are not supported"); + } + dataSourceUserAuthorizationHelper.authorizeDataSource( + this.dataSourceService.getRawDataSourceMetadata( + fullyQualifiedTableName.getDatasourceName())); + String jobName = + fullyQualifiedTableName.getDatasourceName() + + "-" + + fullyQualifiedTableName.getSchemaName() + + "-" + + fullyQualifiedTableName.getTableName() + + "-" + + indexDetails.getIndexName(); + Map tags = new HashMap<>(); + tags.put("index", indexDetails.getIndexName()); + tags.put("datasource", fullyQualifiedTableName.getDatasourceName()); + tags.put("table", fullyQualifiedTableName.getTableName()); + startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), + tags); + return startJobRequest; + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java new file mode 100644 index 0000000000..8c921440a0 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import lombok.Data; +import org.opensearch.sql.spark.rest.model.LangType; + +@Data +public class DispatchQueryRequest { + private final String applicationId; + private final String query; + private final LangType langType; + private final String executionRoleARN; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java new file mode 100644 index 0000000000..296a1fe579 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/FullyQualifiedTableName.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import com.google.gson.Gson; +import java.util.Arrays; +import lombok.Data; + +@Data +public class FullyQualifiedTableName { + private String datasourceName; + private String schemaName; + private String tableName; + + public FullyQualifiedTableName(String fullyQualifiedName) { + String[] parts = fullyQualifiedName.split("\\."); + if (parts.length >= 3) { + datasourceName = parts[0]; + schemaName = parts[1]; + tableName = String.join(".", Arrays.copyOfRange(parts, 2, parts.length)); + } else if (parts.length == 2) { + schemaName = parts[0]; + tableName = parts[1]; + } else if (parts.length == 1) { + tableName = parts[0]; + } + } + + @Override + public String toString() { + return new Gson().toJson(this); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java new file mode 100644 index 0000000000..ca511351af --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import com.google.gson.Gson; +import lombok.Data; + +@Data +public class IndexDetails { + private String indexName; + private FullyQualifiedTableName fullyQualifiedTableName; + + @Override + public String toString() { + return new Gson().toJson(this); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 1e46ae48d2..c1ad979877 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -17,24 +17,27 @@ public class CreateAsyncQueryRequest { private String query; - private String lang; + private LangType lang; public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; - String lang = null; + LangType lang = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); if (fieldName.equals("query")) { query = parser.textOrNull(); - } else if (fieldName.equals("kind")) { - lang = parser.textOrNull(); + } else if (fieldName.equals("lang")) { + lang = LangType.fromString(parser.textOrNull()); } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } + if (lang == null || query == null) { + throw new IllegalArgumentException("lang and query are required fields."); + } return new CreateAsyncQueryRequest(query, lang); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java new file mode 100644 index 0000000000..2e5a37bf4b --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/LangType.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.rest.model; + +public enum LangType { + SQL("sql"), + PPL("ppl"); + private final String text; + + LangType(String text) { + this.text = text; + } + + public String getText() { + return this.text; + } + + /** + * Get LangType from text. + * + * @param text text. + * @return LangType {@link LangType}. + */ + public static LangType fromString(String text) { + for (LangType langType : LangType.values()) { + if (langType.text.equalsIgnoreCase(text)) { + return langType; + } + } + throw new IllegalArgumentException("No LangType with text " + text + " found"); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java new file mode 100644 index 0000000000..e293e925e0 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import lombok.Getter; +import lombok.experimental.UtilityClass; +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor; +import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer; +import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +@UtilityClass +public class SQLQueryUtils { + + public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) { + return trySparkSQL(sqlQuery); + } + + public static IndexDetails extractIndexDetails(String sqlQuery) { + return tryFlintExtensionSQL(sqlQuery); + } + + public static boolean isIndexQuery(String sqlQuery) { + FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = + new FlintSparkSqlExtensionsParser( + new CommonTokenStream( + new FlintSparkSqlExtensionsLexer(new CaseInsensitiveCharStream(sqlQuery)))); + flintSparkSqlExtensionsParser.addErrorListener(new SyntaxAnalysisErrorListener()); + try { + flintSparkSqlExtensionsParser.statement(); + return true; + } catch (SyntaxCheckException syntaxCheckException) { + return false; + } + } + + private static FullyQualifiedTableName trySparkSQL(String sqlQuery) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); + sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); + SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); + SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor(); + statement.accept(sparkSqlTableNameVisitor); + return sparkSqlTableNameVisitor.getFullyQualifiedTableName(); + } + + private static IndexDetails tryFlintExtensionSQL(String sql) { + FlintSparkSqlExtensionsParser flintSparkSqlExtensionsParser = + new FlintSparkSqlExtensionsParser( + new CommonTokenStream( + new FlintSparkSqlExtensionsLexer(new CaseInsensitiveCharStream(sql)))); + flintSparkSqlExtensionsParser.addErrorListener(new SyntaxAnalysisErrorListener()); + FlintSparkSqlExtensionsParser.StatementContext statementContext = + flintSparkSqlExtensionsParser.statement(); + FlintSQLIndexDetailsVisitor flintSQLIndexDetailsVisitor = new FlintSQLIndexDetailsVisitor(); + statementContext.accept(flintSQLIndexDetailsVisitor); + return flintSQLIndexDetailsVisitor.getIndexDetails(); + } + + public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { + + @Getter private FullyQualifiedTableName fullyQualifiedTableName; + + @Override + public Void visitTableName(SqlBaseParser.TableNameContext ctx) { + fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText()); + return super.visitTableName(ctx); + } + + // Extract table name from drop table. + @Override + public Void visitDropTable(SqlBaseParser.DropTableContext ctx) { + for (ParseTree parseTree : ctx.children) { + if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { + fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + } + } + return super.visitDropTable(ctx); + } + + // Extract table name for create Table Statement. + @Override + public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { + for (ParseTree parseTree : ctx.children) { + if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { + fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + } + } + return super.visitCreateTableHeader(ctx); + } + } + + public static class FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor { + + @Getter private final IndexDetails indexDetails; + + public FlintSQLIndexDetailsVisitor() { + this.indexDetails = new IndexDetails(); + } + + @Override + public Void visitIndexName(FlintSparkSqlExtensionsParser.IndexNameContext ctx) { + indexDetails.setIndexName(ctx.getText()); + return super.visitIndexName(ctx); + } + + @Override + public Void visitTableName(FlintSparkSqlExtensionsParser.TableNameContext ctx) { + indexDetails.setFullyQualifiedTableName(new FullyQualifiedTableName(ctx.getText())); + return super.visitTableName(ctx); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 5e832777fc..f9e7a99982 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -28,8 +28,10 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; @ExtendWith(MockitoExtension.class) public class AsyncQueryExecutorServiceImplTest { @@ -44,14 +46,16 @@ void testCreateAsyncQuery() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", "sql"); + new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", LangType.SQL); when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) .thenReturn( "{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}"); when(sparkQueryDispatcher.dispatch( - "00fd775baqpu4g0p", - "select * from my_glue.default.http_logs", - "arn:aws:iam::270824043731:role/emr-job-execution-role")) + new DispatchQueryRequest( + "00fd775baqpu4g0p", + "select * from my_glue.default.http_logs", + LangType.SQL, + "arn:aws:iam::270824043731:role/emr-job-execution-role"))) .thenReturn(EMR_JOB_ID); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); @@ -60,9 +64,11 @@ void testCreateAsyncQuery() { verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); verify(sparkQueryDispatcher, times(1)) .dispatch( - "00fd775baqpu4g0p", - "select * from my_glue.default.http_logs", - "arn:aws:iam::270824043731:role/emr-job-execution-role"); + new DispatchQueryRequest( + "00fd775baqpu4g0p", + "select * from my_glue.default.http_logs", + LangType.SQL, + "arn:aws:iam::270824043731:role/emr-job-execution-role")); Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 925ee73bcd..1fd3e9bfd2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -20,6 +20,7 @@ import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.StartJobRunResult; import com.amazonaws.services.emrserverless.model.ValidationException; +import java.util.HashMap; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -37,7 +38,13 @@ void testStartJobRun() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.startJobRun( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, SPARK_SUBMIT_PARAMETERS); + new StartJobRequest( + QUERY, + EMRS_JOB_NAME, + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + SPARK_SUBMIT_PARAMETERS, + new HashMap<>())); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2000eeefed..03ea38d547 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -12,7 +13,6 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -31,8 +31,12 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.spark.client.SparkJobClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.response.JobExecutionResponseReader; +import org.opensearch.sql.spark.rest.model.LangType; @ExtendWith(MockitoExtension.class) public class SparkQueryDispatcherTest { @@ -40,50 +44,246 @@ public class SparkQueryDispatcherTest { @Mock private SparkJobClient sparkJobClient; @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; + @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @Test - void testDispatch() { + void testDispatchSelectQuery() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("table", "http_logs"); + String query = "select * from my_glue.default.http_logs"; when(sparkJobClient.startJobRun( - QUERY, - "flint-opensearch-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString())) + new StartJobRequest( + query, + "my_glue-default-http_logs", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) .thenReturn(EMR_JOB_ID); - when(dataSourceService.getRawDataSourceMetadata("my_glue")) - .thenReturn(constructMyGlueDataSourceMetadata()); - String jobId = sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE)); verify(sparkJobClient, times(1)) .startJobRun( - QUERY, - "flint-opensearch-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString()); + new StartJobRequest( + query, + "my_glue-default-http_logs", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); Assertions.assertEquals(EMR_JOB_ID, jobId); } + @Test + void testDispatchIndexQuery() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("table", "http_logs"); + tags.put("index", "elb_and_requestUri"); + String query = + "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + when(sparkJobClient.startJobRun( + new StartJobRequest( + query, + "my_glue-default-http_logs-elb_and_requestUri", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE)); + verify(sparkJobClient, times(1)) + .startJobRun( + new StartJobRequest( + query, + "my_glue-default-http_logs-elb_and_requestUri", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); + Assertions.assertEquals(EMR_JOB_ID, jobId); + } + + @Test + void testDispatchWithPPLQuery() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = "select * from my_glue.default.http_logs"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.PPL, EMRS_EXECUTION_ROLE))); + Assertions.assertEquals( + "UnSupported Lang type:: PPL", unsupportedOperationException.getMessage()); + verifyNoInteractions(sparkJobClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testDispatchQueryWithoutATableName() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = "show tables"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE))); + Assertions.assertEquals( + "Queries without a datasource are not supported", + unsupportedOperationException.getMessage()); + verifyNoInteractions(sparkJobClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testDispatchQueryWithoutADataSourceName() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = "select * from default.http_logs"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE))); + Assertions.assertEquals( + "Queries without a datasource are not supported", + unsupportedOperationException.getMessage()); + verifyNoInteractions(sparkJobClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + + @Test + void testDispatchIndexQueryWithoutADatasourceName() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + String query = + "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE))); + Assertions.assertEquals( + "Queries without a datasource are not supported", + unsupportedOperationException.getMessage()); + verifyNoInteractions(sparkJobClient); + verifyNoInteractions(dataSourceService); + verifyNoInteractions(dataSourceUserAuthorizationHelper); + verifyNoInteractions(jobExecutionResponseReader); + } + @Test void testDispatchWithWrongURI() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); when(dataSourceService.getRawDataSourceMetadata("my_glue")) .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); + String query = "select * from my_glue.default.http_logs"; IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, - () -> sparkQueryDispatcher.dispatch(EMRS_APPLICATION_ID, QUERY, EMRS_EXECUTION_ROLE)); + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE))); Assertions.assertEquals( "Bad URI in indexstore configuration of the : my_glue datasoure.", illegalArgumentException.getMessage()); } + @Test + void testDispatchWithUnSupportedDataSourceType() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); + when(dataSourceService.getRawDataSourceMetadata("my_prometheus")) + .thenReturn(constructPrometheusDataSourceType()); + String query = "select * from my_prometheus.default.http_logs"; + UnsupportedOperationException unsupportedOperationException = + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE))); + Assertions.assertEquals( + "UnSupported datasource type for async queries:: PROMETHEUS", + unsupportedOperationException.getMessage()); + } + @Test void testCancelJob() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); when(sparkJobClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn( new CancelJobRunResult() @@ -96,7 +296,11 @@ void testCancelJob() { @Test void testGetQueryResponse() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); when(sparkJobClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); @@ -107,7 +311,11 @@ void testGetQueryResponse() { @Test void testGetQueryResponseWithSuccess() { SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + new SparkQueryDispatcher( + sparkJobClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader); when(sparkJobClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); JSONObject queryResult = new JSONObject(); @@ -185,4 +393,13 @@ private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { dataSourceMetadata.setProperties(properties); return dataSourceMetadata; } + + private DataSourceMetadata constructPrometheusDataSourceType() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_prometheus"); + dataSourceMetadata.setConnector(DataSourceType.PROMETHEUS); + Map properties = new HashMap<>(); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index 6596a9e820..ef49d29829 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -27,6 +27,7 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -56,7 +57,7 @@ public void setUp() { @Test public void testDoExecute() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "sql"); + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", LangType.SQL); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) @@ -72,7 +73,7 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "sql"); + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", LangType.SQL); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); doThrow(new RuntimeException("Error")) diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java new file mode 100644 index 0000000000..0ed80781de --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +@ExtendWith(MockitoExtension.class) +public class SQLQueryUtilsTest { + + @Test + void testExtractionOfTableNameFromSQLQueries() { + String sqlQuery = "select * from my_glue.default.http_logs"; + FullyQualifiedTableName fullyQualifiedTableName = + SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "select * from my_glue.db.http_logs"; + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("db", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "select * from my_glue.http_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("my_glue", fullyQualifiedTableName.getSchemaName()); + Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "select * from http_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "DROP TABLE myS3.default.alb_logs"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = + "CREATE EXTERNAL TABLE\n" + + "myS3.default.alb_logs\n" + + "[ PARTITIONED BY (col_name [, … ] ) ]\n" + + "[ ROW FORMAT DELIMITED row_format ]\n" + + "STORED AS file_format\n" + + "LOCATION { 's3://bucket/folder/' }"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertFalse(SQLQueryUtils.isIndexQuery(sqlQuery)); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + + sqlQuery = "SHOW tables"; + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); + Assertions.assertNull(fullyQualifiedTableName); + } + + @Test + void testExtractionFromFlintIndexQueries() { + String createCoveredIndexQuery = + "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, l_quantity) WITH" + + " (auto_refresh = true)"; + Assertions.assertTrue(SQLQueryUtils.isIndexQuery(createCoveredIndexQuery)); + IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + Assertions.assertEquals("elb_and_requestUri", indexDetails.getIndexName()); + Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); + Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); + Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + } +}