Skip to content

Commit

Permalink
[SPARK-48459][CONNECT][PYTHON] Implement DataFrameQueryContext in Spa…
Browse files Browse the repository at this point in the history
…rk Connect

### What changes were proposed in this pull request?

This PR proposes to Implement DataFrameQueryContext in Spark Connect.

1.  Add two new protobuf messages packed together with `Expression`:

    ```proto
    message Origin {
      // (Required) Indicate the origin type.
      oneof function {
        PythonOrigin python_origin = 1;
      }
    }

    message PythonOrigin {
      // (Required) Name of the origin, for example, the name of the function
      string fragment = 1;

      // (Required) Callsite to show to end users, for example, stacktrace.
      string call_site = 2;
    }
    ```

2. Merge `DataFrameQueryContext.pysparkFragment` and `DataFrameQueryContext.pysparkcallSite` to existing `DataFrameQueryContext.fragment` and `DataFrameQueryContext.callSite`

3. Separate `QueryContext` into `SQLQueryContext` and `DataFrameQueryContext` for consistency w/ Scala side

4. Implement the origin logic. `current_origin` thread local holds the current call site/the function name, and `Expression` gets it from it.
    They are set to individual expression messages, and are used when analysis happens - this resembles Spark SQL implementation.

See also #45377.

### Why are the changes needed?

See #45377

### Does this PR introduce _any_ user-facing change?

Yes, same as #45377 but in Spark Connect.

### How was this patch tested?

Same unittests reused in Spark Connect.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46789 from HyukjinKwon/connect-context.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
HyukjinKwon committed Jun 18, 2024
1 parent 58701d8 commit 80bba44
Show file tree
Hide file tree
Showing 19 changed files with 463 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,18 @@ message ResourceProfile {
// (e.g., cores, memory, CPU) to its specific request.
map<string, TaskResourceRequest> task_resources = 2;
}

message Origin {
// (Required) Indicate the origin type.
oneof function {
PythonOrigin python_origin = 1;
}
}

message PythonOrigin {
// (Required) Name of the origin, for example, the name of the function
string fragment = 1;

// (Required) Callsite to show to end users, for example, stacktrace.
string call_site = 2;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ syntax = 'proto3';

import "google/protobuf/any.proto";
import "spark/connect/types.proto";
import "spark/connect/common.proto";

package spark.connect;

Expand All @@ -30,6 +31,7 @@ option go_package = "internal/generated";
// expressions in SQL appear.
message Expression {

ExpressionCommon common = 18;
oneof expr_type {
Literal literal = 1;
UnresolvedAttribute unresolved_attribute = 2;
Expand Down Expand Up @@ -342,6 +344,11 @@ message Expression {
}
}

message ExpressionCommon {
// (Required) Keep the information of the origin for this expression such as stacktrace.
Origin origin = 1;
}

message CommonInlineUserDefinedFunction {
// (Required) Name of the user-defined function.
string function_name = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ message Unknown {}

// Common metadata of all relations.
message RelationCommon {
// TODO(SPARK-48639): Add origin like Expression.ExpressionCommon

// (Required) Shared relation metadata.
string source_info = 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
Expand All @@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
Expand Down Expand Up @@ -1471,7 +1472,21 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
def transformExpression(exp: proto.Expression): Expression = {
def transformExpression(exp: proto.Expression): Expression = if (exp.hasCommon) {
try {
val origin = exp.getCommon.getOrigin
PySparkCurrentOrigin.set(
origin.getPythonOrigin.getFragment,
origin.getPythonOrigin.getCallSite)
withOrigin { doTransformExpression(exp) }
} finally {
PySparkCurrentOrigin.clear()
}
} else {
doTransformExpression(exp)
}

private def doTransformExpression(exp: proto.Expression): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods

import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.{Logging, MDC}
Expand Down Expand Up @@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
}
for (queryCtx <- sparkThrowable.getQueryContext) {
sparkThrowableBuilder.addQueryContexts(
FetchErrorDetailsResponse.QueryContext
.newBuilder()
val builder = FetchErrorDetailsResponse.QueryContext
.newBuilder()
val context = if (queryCtx.contextType() == QueryContextType.SQL) {
builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
.setObjectType(queryCtx.objectType())
.setObjectName(queryCtx.objectName())
.setStartIndex(queryCtx.startIndex())
.setStopIndex(queryCtx.stopIndex())
.setFragment(queryCtx.fragment())
.build())
.setSummary(queryCtx.summary())
.build()
} else {
builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
.setFragment(queryCtx.fragment())
.setCallSite(queryCtx.callSite())
.setSummary(queryCtx.summary())
.build()
}
sparkThrowableBuilder.addQueryContexts(context)
}
if (sparkThrowable.getSqlState != null) {
sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
Expand Down
51 changes: 37 additions & 14 deletions python/pyspark/errors/exceptions/captured.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
return [QueryContext(q) for q in self._origin.getQueryContext()]
contexts: List[BaseQueryContext] = []
for q in self._origin.getQueryContext():
if q.contextType().toString() == "SQL":
contexts.append(SQLQueryContext(q))
else:
contexts.append(DataFrameQueryContext(q))

return contexts
else:
return []

Expand Down Expand Up @@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException):
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.contextType().toString()
assert context_type in ("SQL", "DataFrame")
if context_type == "DataFrame":
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.objectType())
Expand All @@ -409,13 +411,34 @@ def fragment(self) -> str:
def callSite(self) -> str:
return str(self._q.callSite())

def pysparkFragment(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkFragment())
def summary(self) -> str:
return str(self._q.summary())


class DataFrameQueryContext(BaseQueryContext):
def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
return str(self._q.objectType())

def objectName(self) -> str:
return str(self._q.objectName())

def pysparkCallSite(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkCallSite())
def startIndex(self) -> int:
return int(self._q.startIndex())

def stopIndex(self) -> int:
return int(self._q.stopIndex())

def fragment(self) -> str:
return str(self._q.fragment())

def callSite(self) -> str:
return str(self._q.callSite())

def summary(self) -> str:
return str(self._q.summary())
83 changes: 75 additions & 8 deletions python/pyspark/errors/exceptions/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def convert_exception(
)
query_contexts = []
for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
query_contexts.append(QueryContext(query_context))
if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
query_contexts.append(SQLQueryContext(query_context))
else:
query_contexts.append(DataFrameQueryContext(query_context))

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(
Expand Down Expand Up @@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.context_type

if int(context_type) == QueryContextType.DataFrame.value:
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.object_type)
Expand All @@ -457,6 +455,75 @@ def stopIndex(self) -> int:
def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def summary(self) -> str:
return str(self._q.summary)


class DataFrameQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def objectName(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def startIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def stopIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
return str(self._q.call_site)

Expand Down
Loading

0 comments on commit 80bba44

Please sign in to comment.