# [SPARK-50130][SQL][FOLLOWUP] Make Encoder generation lazy

### What changes were proposed in this pull request?

Makes `Encoder` generation lazy: `Dataset` now receives an encoder generator (`() => Encoder[T]`) and materializes the encoder in a `lazy val` on first access, instead of requiring a fully built encoder at construction time.

### Why are the changes needed?

Previously, a `Dataset` over a lazily analyzed plan was given a placeholder `RowEncoder` built from an empty schema (`new StructType()`). Once the plan was actually analyzed, that encoder no longer matched the real output schema and could cause unexpected behavior. Generating the encoder lazily, after analysis has run, avoids the mismatch.
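
For intuition, here is a minimal sketch in plain Scala (illustrative classes, not Spark's) of why the deferral matters: an eagerly built placeholder encoder is frozen with whatever schema exists at construction time, while a generator picks up the state available when the encoder is first forced.

```scala
// Simplified sketch, not Spark's actual classes.
object LazyEncoderSketch {
  final case class Encoder(fieldNames: Seq[String])

  // The "dataset" receives a generator instead of a ready-made encoder.
  final class Dataset(encoderGenerator: () => Encoder) {
    // Materialized at most once, on first access, by which time
    // analysis (here: the field list) has completed.
    lazy val encoder: Encoder = encoderGenerator()
  }

  def main(args: Array[String]): Unit = {
    var analyzedFields: Seq[String] = Seq.empty         // analysis not run yet
    val ds = new Dataset(() => Encoder(analyzedFields))
    analyzedFields = Seq("id")                          // analysis runs later
    println(ds.encoder.fieldNames)                      // List(id), not empty
  }
}
```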

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48829 from ueshin/issues/SPARK-50130/lazy_encoder.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
ueshin committed Nov 20, 2024
1 parent 30d0b01 commit ad46db4
Showing 3 changed files with 25 additions and 27 deletions.
#### sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (13 additions, 22 deletions)

```diff
@@ -95,13 +95,8 @@ private[sql] object Dataset {
   def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
     sparkSession.withActive {
       val qe = sparkSession.sessionState.executePlan(logicalPlan)
-      val encoder = if (qe.isLazyAnalysis) {
-        RowEncoder.encoderFor(new StructType())
-      } else {
-        qe.assertAnalyzed()
-        RowEncoder.encoderFor(qe.analyzed.schema)
-      }
-      new Dataset[Row](qe, encoder)
+      if (!qe.isLazyAnalysis) qe.assertAnalyzed()
+      new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
     }
 
   def ofRows(
@@ -111,13 +106,8 @@ private[sql] object Dataset {
     sparkSession.withActive {
       val qe = new QueryExecution(
         sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
-      val encoder = if (qe.isLazyAnalysis) {
-        RowEncoder.encoderFor(new StructType())
-      } else {
-        qe.assertAnalyzed()
-        RowEncoder.encoderFor(qe.analyzed.schema)
-      }
-      new Dataset[Row](qe, encoder)
+      if (!qe.isLazyAnalysis) qe.assertAnalyzed()
+      new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
     }
 
   /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
@@ -129,13 +119,8 @@
       : DataFrame = sparkSession.withActive {
     val qe = new QueryExecution(
       sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
-    val encoder = if (qe.isLazyAnalysis) {
-      RowEncoder.encoderFor(new StructType())
-    } else {
-      qe.assertAnalyzed()
-      RowEncoder.encoderFor(qe.analyzed.schema)
-    }
-    new Dataset[Row](qe, encoder)
+    if (!qe.isLazyAnalysis) qe.assertAnalyzed()
+    new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
   }
 }
 
@@ -229,7 +214,7 @@ private[sql] object Dataset {
 @Stable
 class Dataset[T] private[sql](
     @DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
-    @DeveloperApi @Unstable @transient val encoder: Encoder[T])
+    @transient encoderGenerator: () => Encoder[T])
   extends api.Dataset[T] {
   type DS[U] = Dataset[U]
 
@@ -252,6 +237,10 @@ class Dataset[T] private[sql](
   // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
   // you wrap it with `withNewExecutionId` if this actions doesn't call other action.
 
+  private[sql] def this(queryExecution: QueryExecution, encoder: Encoder[T]) = {
+    this(queryExecution, () => encoder)
+  }
+
   def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
     this(sparkSession.sessionState.executePlan(logicalPlan), encoder)
   }
@@ -274,6 +263,8 @@ class Dataset[T] private[sql](
     }
   }
 
+  @DeveloperApi @Unstable @transient lazy val encoder: Encoder[T] = encoderGenerator()
+
   /**
    * Expose the encoder as implicit so it can be used to construct new Dataset objects that have
    * the same external type.
```
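
The new `private[sql]` auxiliary constructor keeps the old eager signature working by wrapping the given encoder in a constant thunk, while the primary constructor defers the work until `encoder` is first forced. A minimal sketch of that pattern (the `Holder` type is illustrative, not Spark's):

```scala
// Illustrative sketch of the compatibility pattern, not Spark's classes:
// the primary constructor takes a generator; the auxiliary one wraps an
// already-built value so existing call sites keep working unchanged.
class Holder[T](generator: () => T) {
  def this(value: T) = this(() => value) // eager value, wrapped in a thunk
  lazy val get: T = generator()          // computed at most once, on demand
}

object HolderDemo {
  def main(args: Array[String]): Unit = {
    val eager = new Holder("already built")                    // old-style call
    val deferred = new Holder[String](() => "built on demand") // new-style call
    println(eager.get)    // already built
    println(deferred.get) // built on demand
  }
}
```

Since `encoder` is a `lazy val`, the generator runs at most once and the result is cached for subsequent accesses.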

#### DataFrameSubquerySuite.scala

```diff
@@ -54,11 +54,18 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
   }
 
   test("unanalyzable expression") {
-    val exception = intercept[AnalysisException] {
-      spark.range(1).select($"id" === $"id".outer()).schema
-    }
+    val sub = spark.range(1).select($"id" === $"id".outer())
+
+    checkError(
+      intercept[AnalysisException](sub.schema),
+      condition = "UNANALYZABLE_EXPRESSION",
+      parameters = Map("expr" -> "\"outer(id)\""),
+      queryContext =
+        Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern))
+    )
 
     checkError(
-      exception,
+      intercept[AnalysisException](sub.encoder),
       condition = "UNANALYZABLE_EXPRESSION",
       parameters = Map("expr" -> "\"outer(id)\""),
       queryContext =
```
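
Because the encoder is now generated lazily, an unanalyzable plan no longer fails when the `Dataset` is constructed; the error surfaces when either `schema` or `encoder` is first forced, which is why the test asserts both access paths. A hedged sketch of that behavior with ScalaTest (illustrative classes, not Spark's):

```scala
import org.scalatest.funsuite.AnyFunSuite

// Illustrative sketch, not Spark's classes: each lazily computed member
// forces the failing analysis independently, so both paths are asserted.
class LazyFailureSketch extends AnyFunSuite {
  class Sub(analyze: () => Seq[String]) {
    lazy val schema: Seq[String] = analyze()
    lazy val encoder: String = analyze().mkString(",")
  }

  test("both lazy members surface the analysis error") {
    val sub = new Sub(() => throw new IllegalArgumentException("unanalyzable"))
    intercept[IllegalArgumentException](sub.schema)  // forcing schema fails
    intercept[IllegalArgumentException](sub.encoder) // forcing encoder fails too
  }
}
```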

#### UDFSuite.scala

```diff
@@ -1205,7 +1205,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       dt
     )
     checkError(
-      intercept[AnalysisException](spark.range(1).select(f())),
+      intercept[AnalysisException](spark.range(1).select(f()).encoder),
       condition = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER",
       sqlState = "0A000",
       parameters = Map("dataType" -> s"\"${dt.sql}\"")
```
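
The same deferral explains this one-line change: `spark.range(1).select(f())` no longer builds the encoder at construction, so the assertion forces `.encoder` explicitly to make `UNSUPPORTED_DATA_TYPE_FOR_ENCODER` surface at the point the encoder is materialized.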
