Skip to content

Commit

Permalink
Spark: Fail if temp functions are used in views (apache#9675)
Browse files Browse the repository at this point in the history
  • Loading branch information
nastra authored Feb 25, 2024
1 parent 56da99b commit b3c68e5
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,20 @@ import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView
import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View
import org.apache.spark.sql.catalyst.plans.logical.views.ShowIcebergViews
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_FUNCTION
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.LookupCatalog
import scala.collection.mutable

/**
* ResolveSessionCatalog exits early for some v2 View commands,
* thus they are pre-substituted here and then handled in ResolveViews
*/
case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
Expand Down Expand Up @@ -83,6 +87,13 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
catalogManager.v1SessionCatalog.isTempView(nameParts)
}

/**
 * Tells whether `nameParts` names a temporary function known to the v1 session catalog.
 *
 * A name made of more than one part is reported as non-temporary without consulting the
 * catalog; only names with at most one part are looked up.
 *
 * @param nameParts the parts of the function name as written in the query
 * @return true if the v1 session catalog reports the name as a temporary function
 */
private def isTempFunction(nameParts: Seq[String]): Boolean =
  nameParts.size <= 1 &&
    catalogManager.v1SessionCatalog.isTemporaryFunction(nameParts.asFunctionIdentifier)

private object ResolvedIdent {
def unapply(unresolved: UnresolvedIdentifier): Option[ResolvedIdentifier] = unresolved match {
case UnresolvedIdentifier(nameParts, true) if isTempView(nameParts) =>
Expand All @@ -102,20 +113,20 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
private def verifyTemporaryObjectsDontExist(
    name: Identifier,
    child: LogicalPlan): Unit = {
  // Rejects a view definition whose query refers to temporary objects:
  //   name  - identifier of the view being created, only used in the error message
  //   child - plan of the view's query that is inspected for temporary references
  // An AnalysisException is thrown for the first category (views, then functions)
  // that has at least one match, listing all matches of that category.

  // temporary views are reported with their quoted multi-part names
  val tempViews = collectTemporaryViews(child)
  if (tempViews.nonEmpty) {
    throw invalidRefToTempObject(name, tempViews.map(v => v.quoted).mkString("[", ", ", "]"), "view")
  }

  // temporary functions are reported by their plain names
  val tempFunctions = collectTemporaryFunctions(child)
  if (tempFunctions.nonEmpty) {
    throw invalidRefToTempObject(name, tempFunctions.mkString("[", ", ", "]"), "function")
  }
}

/**
 * Builds (but does not throw) the exception raised when a view references temporary objects.
 *
 * @param name identifier of the view that was being created
 * @param tempObjectNames already formatted list of the offending temporary object names
 * @param tempObjectType kind of temporary object as it should appear in the message
 * @return an AnalysisException carrying the formatted message
 */
private def invalidRefToTempObject(
    name: Identifier,
    tempObjectNames: String,
    tempObjectType: String): AnalysisException = {
  val message = "Cannot create view %s that references temporary %s: %s"
    .format(name, tempObjectType, tempObjectNames)
  new AnalysisException(message)
}

/**
Expand Down Expand Up @@ -149,4 +160,20 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
None
}
}

/**
 * Collect the names of all temporary functions referenced by the given plan, including
 * those referenced from plans nested inside subquery expressions.
 *
 * @param child plan to inspect
 * @return the distinct temporary function names, in the order they were first encountered
 */
private def collectTemporaryFunctions(child: LogicalPlan): Seq[String] = {
  // insertion-ordered set: de-duplicates names while keeping the reported order stable,
  // so the resulting error message does not depend on hash iteration order
  val tempFunctions = mutable.LinkedHashSet.empty[String]
  // the traversal is only used to visit expressions; every expression is handed back unchanged
  child.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
    case f @ UnresolvedFunction(nameParts, _, _, _, _) if isTempFunction(nameParts) =>
      tempFunctions += nameParts.head
      f
    case e: SubqueryExpression =>
      // subquery plans are separate trees, so they are walked recursively
      tempFunctions ++= collectTemporaryFunctions(e.plan)
      e
  }
  tempFunctions.toSeq
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,39 @@ public void readFromViewReferencingGlobalTempView() throws NoSuchTableException
.hasMessageContaining("cannot be found");
}

@Test
public void readFromViewReferencingTempFunction() throws NoSuchTableException {
  insertRows(10);
  String viewName = viewName("viewReferencingTempFunction");
  String tempFunction = "test_avg";
  String query = String.format("SELECT %s(id) FROM %s", tempFunction, tableName);

  // register the temporary function that the view's query refers to
  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  ViewCatalog catalog = viewCatalog();
  Schema tableSchema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema();
  TableIdentifier viewIdentifier = TableIdentifier.of(NAMESPACE, viewName);

  // creating such a view through SQL would be rejected, but nothing prevents it
  // when the view is built through the catalog API directly
  catalog
      .buildView(viewIdentifier)
      .withQuery("spark", query)
      .withDefaultNamespace(NAMESPACE)
      .withDefaultCatalog(catalogName)
      .withSchema(tableSchema)
      .create();

  // the query itself runs outside of the view and yields a single row
  assertThat(sql(query)).hasSize(1).containsExactly(row(5.5));

  // going through the view must fail, because the TEMP FUNCTION can't be found from there
  assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining("The function")
      .hasMessageContaining(tempFunction)
      .hasMessageContaining("cannot be found");
}

@Test
public void readFromViewWithCTE() throws NoSuchTableException {
insertRows(10);
Expand Down Expand Up @@ -947,9 +980,9 @@ public void createViewReferencingTempView() throws NoSuchTableException {
assertThatThrownBy(
() -> sql("CREATE VIEW %s AS SELECT id FROM %s", viewReferencingTempView, tempView))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("Cannot create the persistent object")
.hasMessageContaining(viewReferencingTempView)
.hasMessageContaining("of the type VIEW because it references to the temporary object")
.hasMessageContaining(
String.format("Cannot create view %s.%s", NAMESPACE, viewReferencingTempView))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(tempView);
}

Expand All @@ -970,10 +1003,59 @@ public void createViewReferencingGlobalTempView() throws NoSuchTableException {
"CREATE VIEW %s AS SELECT id FROM global_temp.%s",
viewReferencingTempView, globalTempView))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("Cannot create the persistent object")
.hasMessageContaining(viewReferencingTempView)
.hasMessageContaining("of the type VIEW because it references to the temporary object")
.hasMessageContaining(globalTempView);
.hasMessageContaining(
String.format("Cannot create view %s.%s", NAMESPACE, viewReferencingTempView))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView));
}

@Test
public void createViewReferencingTempFunction() {
  String viewName = viewName("viewReferencingTemporaryFunction");
  String tempFunction = "test_avg_func";
  String expectedPrefix = String.format("Cannot create view %s.%s", NAMESPACE, viewName);

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  // a view whose query calls a TEMP FUNCTION must be rejected at creation time
  assertThatThrownBy(
          () -> sql("CREATE VIEW %s AS SELECT %s(id) FROM %s", viewName, tempFunction, tableName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining(expectedPrefix)
      .hasMessageContaining("that references temporary function:")
      .hasMessageContaining(tempFunction);
}

@Test
public void createViewReferencingQualifiedTempFunction() {
  String viewName = viewName("viewReferencingTemporaryFunction");
  String tempFunction = "test_avg_func_qualified";

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  // referencing the TEMP function as catalog.schema.name must not resolve
  String fullyQualified =
      String.format("`%s`.`%s`.`%s`", catalogName, NAMESPACE, tempFunction);
  assertThatThrownBy(
          () ->
              sql(
                  "CREATE VIEW %s AS SELECT %s.%s.%s(id) FROM %s",
                  viewName, catalogName, NAMESPACE, tempFunction, tableName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining("Cannot resolve function")
      .hasMessageContaining(fullyQualified);

  // referencing the TEMP function as schema.name must not resolve either
  String schemaQualified = String.format("`%s`.`%s`", NAMESPACE, tempFunction);
  assertThatThrownBy(
          () ->
              sql(
                  "CREATE VIEW %s AS SELECT %s.%s(id) FROM %s",
                  viewName, NAMESPACE, tempFunction, tableName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining("Cannot resolve function")
      .hasMessageContaining(schemaQualified);
}

@Test
Expand Down Expand Up @@ -1118,12 +1200,32 @@ public void createViewWithCTEReferencingTempView() {

assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("Cannot create the persistent object")
.hasMessageContaining(viewName)
.hasMessageContaining("of the type VIEW because it references to the temporary object")
.hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(tempViewInCTE);
}

@Test
public void createViewWithCTEReferencingTempFunction() {
  String viewName = "viewWithCTEReferencingTempFunction";
  String functionName = "avg_function_in_cte";
  // the CTE exposes a single column named `avg`, so the outer query groups by that column;
  // this keeps the temporary function as the only invalid part of the statement
  String sql =
      String.format(
          "WITH avg_data AS (SELECT %s(id) as avg FROM %s) "
              + "SELECT avg, count(1) AS count FROM avg_data GROUP BY avg",
          functionName, tableName);

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      functionName);

  // a TEMP FUNCTION used inside a CTE must still prevent the view from being created
  assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
      .hasMessageContaining("that references temporary function:")
      .hasMessageContaining(functionName);
}

@Test
public void createViewWithNonExistingQueryColumn() {
assertThatThrownBy(
Expand All @@ -1147,9 +1249,9 @@ public void createViewWithSubqueryExpressionUsingTempView() {

assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining(String.format("Cannot create the persistent object %s", viewName))
.hasMessageContaining(
String.format("because it references to the temporary object %s", tempView));
.hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(tempView);
}

@Test
Expand All @@ -1167,10 +1269,29 @@ public void createViewWithSubqueryExpressionUsingGlobalTempView() {

assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining(String.format("Cannot create the persistent object %s", viewName))
.hasMessageContaining(
String.format(
"because it references to the temporary object global_temp.%s", globalTempView));
.hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView));
}

@Test
public void createViewWithSubqueryExpressionUsingTempFunction() {
  String viewName = viewName("viewWithSubqueryExpression");
  String tempFunction = "avg_function_in_subquery";
  String query =
      String.format(
          "SELECT * FROM %s WHERE id < (SELECT %s(id) FROM %s)",
          tableName, tempFunction, tableName);
  String expectedPrefix = String.format("Cannot create view %s.%s", NAMESPACE, viewName);

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  // the TEMP FUNCTION only appears inside the subquery expression, which must still be detected
  assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, query))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining(expectedPrefix)
      .hasMessageContaining("that references temporary function:")
      .hasMessageContaining(tempFunction);
}

@Test
Expand Down

0 comments on commit b3c68e5

Please sign in to comment.