Skip to content

Commit

Permalink
Restrict UDF functions (#2790)
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <[email protected]>
(cherry picked from commit db2bd66)
  • Loading branch information
vmmusings committed Aug 1, 2024
1 parent a87e13c commit 1cab38a
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 31 deletions.
1 change: 1 addition & 0 deletions async-query-core/src/main/antlr/SqlBaseLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ NANOSECOND: 'NANOSECOND';
NANOSECONDS: 'NANOSECONDS';
NATURAL: 'NATURAL';
NO: 'NO';
NONE: 'NONE';
NOT: 'NOT';
NULL: 'NULL';
NULLS: 'NULLS';
Expand Down
52 changes: 39 additions & 13 deletions async-query-core/src/main/antlr/SqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ singleCompoundStatement
;

beginEndCompoundBlock
: BEGIN compoundBody END
: beginLabel? BEGIN compoundBody END endLabel?
;

compoundBody
Expand All @@ -61,11 +61,26 @@ compoundBody

compoundStatement
: statement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
;

setStatementWithOptionalVarKeyword
: SET (VARIABLE | VAR)? assignmentList #setVariableWithOptionalKeyword
| SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword
;

singleStatement
: statement SEMICOLON* EOF
: (statement|setResetStatement) SEMICOLON* EOF
;

beginLabel
: multipartIdentifier COLON
;

endLabel
: multipartIdentifier
;

singleExpression
Expand Down Expand Up @@ -175,6 +190,8 @@ statement
| ALTER TABLE identifierReference
(partitionSpec)? SET locationSpec #setTableLocation
| ALTER TABLE identifierReference RECOVER PARTITIONS #recoverPartitions
| ALTER TABLE identifierReference
(clusterBySpec | CLUSTER BY NONE) #alterClusterBy
| DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable
| DROP VIEW (IF EXISTS)? identifierReference #dropView
| CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
Expand Down Expand Up @@ -203,7 +220,7 @@ statement
identifierReference dataType? variableDefaultExpression? #createVariable
| DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable
| EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)?
statement #explain
(statement|setResetStatement) #explain
| SHOW TABLES ((FROM | IN) identifierReference)?
(LIKE? pattern=stringLit)? #showTables
| SHOW TABLE EXTENDED ((FROM | IN) ns=identifierReference)?
Expand Down Expand Up @@ -242,26 +259,29 @@ statement
| (MSCK)? REPAIR TABLE identifierReference
(option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable
| op=(ADD | LIST) identifier .*? #manageResource
| SET COLLATION collationName=identifier #setCollation
| SET ROLE .*? #failNativeCommand
| CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
identifierReference (USING indexType=identifier)?
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

setResetStatement
: SET COLLATION collationName=identifier #setCollation
| SET ROLE .*? #failSetRole
| SET TIME ZONE interval #setTimeZone
| SET TIME ZONE timezone #setTimeZone
| SET TIME ZONE .*? #setTimeZone
| SET (VARIABLE | VAR) assignmentList #setVariable
| SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
LEFT_PAREN query RIGHT_PAREN #setVariable
LEFT_PAREN query RIGHT_PAREN #setVariable
| SET configKey EQ configValue #setQuotedConfiguration
| SET configKey (EQ .*?)? #setConfiguration
| SET .*? EQ configValue #setQuotedConfiguration
| SET .*? #setConfiguration
| RESET configKey #resetQuotedConfiguration
| RESET .*? #resetConfiguration
| CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
identifierReference (USING indexType=identifier)?
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

executeImmediate
Expand Down Expand Up @@ -854,13 +874,17 @@ identifierComment

relationPrimary
: identifierReference temporalClause?
sample? tableAlias #tableName
optionsClause? sample? tableAlias #tableName
| LEFT_PAREN query RIGHT_PAREN sample? tableAlias #aliasedQuery
| LEFT_PAREN relation RIGHT_PAREN sample? tableAlias #aliasedRelation
| inlineTable #inlineTableDefault2
| functionTable #tableValuedFunction
;

optionsClause
: WITH options=propertyList
;

inlineTable
: VALUES expression (COMMA expression)* tableAlias
;
Expand Down Expand Up @@ -1573,6 +1597,7 @@ ansiNonReserved
| NANOSECOND
| NANOSECONDS
| NO
| NONE
| NULLS
| NUMERIC
| OF
Expand Down Expand Up @@ -1921,6 +1946,7 @@ nonReserved
| NANOSECOND
| NANOSECONDS
| NO
| NONE
| NOT
| NULL
| NULLS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.dispatcher;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -45,25 +46,50 @@ public DispatchQueryResponse dispatch(
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource(), asyncQueryRequestContext);

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
String query = dispatchQueryRequest.getQuery();

if (SQLQueryUtils.isFlintExtensionQuery(query)) {
return handleFlintExtensionQuery(
dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
}

List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
if (!validationErrors.isEmpty()) {
throw new IllegalArgumentException(
"Query is not allowed: " + String.join(", ", validationErrors));
}
}
return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
}

private DispatchQueryResponse handleFlintExtensionQuery(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext,
DataSourceMetadata dataSourceMetadata) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
}

private DispatchQueryResponse handleDefaultQuery(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext,
DataSourceMetadata dataSourceMetadata) {

DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
}

private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -73,6 +75,32 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
}
}

public static List<String> validateSparkSqlQuery(String sqlQuery) {
SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor();
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
try {
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
sparkSqlValidatorVisitor.visit(statement);
return sparkSqlValidatorVisitor.getValidationErrors();
} catch (SyntaxCheckException syntaxCheckException) {
return Collections.emptyList();
}
}

private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private final List<String> validationErrors = new ArrayList<>();

@Override
public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
validationErrors.add("Creating user-defined functions is not allowed");
return super.visitCreateFunction(ctx);
}
}

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobRun;
import com.amazonaws.services.emrserverless.model.JobRunState;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -447,6 +449,77 @@ void testDispatchWithPPLQuery() {
verifyNoInteractions(flintIndexMetadataService);
}

@Test
void testDispatchWithSparkUDFQuery() {
List<String> udfQueries = new ArrayList<>();
udfQueries.add(
"CREATE FUNCTION celsius_to_fahrenheit AS 'org.apache.spark.sql.functions.expr(\"(celsius *"
+ " 9/5) + 32\")'");
udfQueries.add(
"CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'");
for (String query : udfQueries) {
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
.thenReturn(dataSourceMetadata);

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class,
() ->
sparkQueryDispatcher.dispatch(
getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(),
asyncQueryRequestContext));
Assertions.assertEquals(
"Query is not allowed: Creating user-defined functions is not allowed",
illegalArgumentException.getMessage());
verifyNoInteractions(emrServerlessClient);
verifyNoInteractions(flintIndexMetadataService);
}
}

@Test
void testInvalidSQLQueryDispatchToSpark() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
String query = "myselect 1";
String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
StartJobRequest expected =
new StartJobRequest(
"TEST_CLUSTER:batch",
null,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,
sparkSubmitParameters,
tags,
false,
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
DispatchQueryRequest.builder()
.applicationId(EMRS_APPLICATION_ID)
.query(query)
.datasource(MY_GLUE)
.langType(LangType.SQL)
.executionRoleARN(EMRS_EXECUTION_ROLE)
.clusterName(TEST_CLUSTER_NAME)
.sparkSubmitParameterModifier(sparkSubmitParameterModifier)
.build(),
asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

@Test
void testDispatchQueryWithoutATableAndDataSourceName() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,25 @@ void testAutoRefresh() {
.autoRefresh());
}

@Test
void testValidateSparkSqlQuery_ValidQuery() {
String validQuery = "SELECT * FROM users WHERE age > 18";
List<String> errors = SQLQueryUtils.validateSparkSqlQuery(validQuery);
assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
}

@Test
void testValidateSparkSqlQuery_InvalidQuery() {
String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'";
List<String> errors = SQLQueryUtils.validateSparkSqlQuery(invalidQuery);
assertFalse(errors.isEmpty(), "Invalid query should produce errors");
assertEquals(1, errors.size(), "Should have one error");
assertEquals(
"Creating user-defined functions is not allowed",
errors.get(0),
"Error message should match");
}

@Getter
protected static class IndexQuery {
private String query;
Expand Down
2 changes: 2 additions & 0 deletions docs/user/interfaces/asyncqueryinterface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Async Query Creation API
======================================
If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``.

Limitation: Spark SQL queries that create User-Defined Functions (UDFs) are not allowed.

HTTP URI: ``_plugins/_async_query``

HTTP VERB: ``POST``
Expand Down

0 comments on commit 1cab38a

Please sign in to comment.