Skip to content

Commit

Permalink
[SPARK-49755][CONNECT] Remove special casing for avro functions in Co…
Browse files Browse the repository at this point in the history
…nnect

### What changes were proposed in this pull request?
apply the built-in registered functions

### Why are the changes needed?
code simplification

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
updated tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48209 from zhengruifeng/connect_avro.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
  • Loading branch information
zhengruifeng authored and LuciferYang committed Sep 23, 2024
1 parent e1637e3 commit fec1562
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ case class FromAvro(child: Expression, jsonFormatSchema: Expression, options: Ex
override def second: Expression = jsonFormatSchema
override def third: Expression = options

def this(child: Expression, jsonFormatSchema: Expression) =
this(child, jsonFormatSchema, Literal.create(null))

override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = {
copy(child = newFirst, jsonFormatSchema = newSecond, options = newThird)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes)#0]
Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes, {"type": "int", "name": "id"}, map(mode, FAILFAST, compression, zstandard))#0]
+- LocalRelation <empty>, [id#0L, bytes#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [from_avro(bytes#0, {"type": "string", "name": "name"}) AS from_avro(bytes)#0]
Project [from_avro(bytes#0, {"type": "string", "name": "name"}) AS from_avro(bytes, {"type": "string", "name": "name"}, NULL)#0]
+- LocalRelation <empty>, [id#0L, bytes#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [to_avro(a#0, Some({"type": "int", "name": "id"})) AS to_avro(a)#0]
Project [to_avro(a#0, Some({"type": "int", "name": "id"})) AS to_avro(a, {"type": "int", "name": "id"})#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [to_avro(id#0L, None) AS to_avro(id)#0]
Project [to_avro(id#0L, None) AS to_avro(id, NULL)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
2 changes: 1 addition & 1 deletion sql/connect/server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-avro_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
Expand Down Expand Up @@ -1523,8 +1522,7 @@ class SparkConnectPlanner(
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
transformUnresolvedAttribute(exp.getUnresolvedAttribute)
case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
transformUnregisteredFunction(exp.getUnresolvedFunction)
.getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
transformUnresolvedFunction(exp.getUnresolvedFunction)
case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
case proto.Expression.ExprTypeCase.EXPRESSION_STRING =>
transformExpressionString(exp.getExpressionString)
Expand Down Expand Up @@ -1844,49 +1842,6 @@ class SparkConnectPlanner(
UnresolvedNamedLambdaVariable(variable.getNamePartsList.asScala.toSeq)
}

/**
* For some reason, not all functions are registered in 'FunctionRegistry'. For a unregistered
* function, we can still wrap it under the proto 'UnresolvedFunction', and then resolve it in
* this method.
*/
private def transformUnregisteredFunction(
fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
fun.getFunctionName match {
// Avro-specific functions
case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val jsonFormatSchema = extractString(children(1), "jsonFormatSchema")
var options = Map.empty[String, String]
if (fun.getArgumentsCount == 3) {
options = extractMapData(children(2), "Options")
}
Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))

case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
var jsonFormatSchema = Option.empty[String]
if (fun.getArgumentsCount == 2) {
jsonFormatSchema = Some(extractString(children(1), "jsonFormatSchema"))
}
Some(CatalystDataToAvro(children.head, jsonFormatSchema))

case _ => None
}
}

private def extractString(expr: Expression, field: String): String = expr match {
case Literal(s, StringType) if s != null => s.toString
case other => throw InvalidPlanInput(s"$field should be a literal string, but got $other")
}

@scala.annotation.tailrec
private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match {
case map: CreateMap => ExprUtils.convertToMapData(map)
case UnresolvedFunction(Seq("map"), args, _, _, _, _, _) =>
extractMapData(CreateMap(args), field)
case other => throw InvalidPlanInput(s"$field should be created by map, but got $other")
}

private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
if (alias.getNameCount == 1) {
val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {
Expand Down

0 comments on commit fec1562

Please sign in to comment.