[SPARK-42688][CONNECT] Rename Connect proto Request client_id to session_id

### What changes were proposed in this pull request?

Rename the `client_id` field in the Connect proto requests (and the matching field in their responses) to `session_id`.

On the one hand, when I read `client_id` I was confused about what it is used for, even after reading the proto documentation.

On the other hand, the client sides already use `session_id`:
https://github.com/apache/spark/blob/9bf174f9722e34f13bfaede5e59f989bf2a511e9/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala#L51
https://github.com/apache/spark/blob/9bf174f9722e34f13bfaede5e59f989bf2a511e9/python/pyspark/sql/connect/client.py#L522
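
For illustration, a minimal sketch of what a request looks like after the rename, assuming the generated proto classes under `org.apache.spark.connect.proto` (the package used by the JVM client linked above); the session UUID here is a hypothetical value:

```scala
import java.util.UUID

import org.apache.spark.connect.proto

object SessionIdExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical value: the real client generates one UUID per SparkSession
    // and reuses it on every request so the server can collate responses.
    val sessionId = UUID.randomUUID().toString

    // After this PR the builder exposes setSessionId(...) where it previously
    // exposed setClientId(...); the proto field number (1) is unchanged.
    val request = proto.AnalyzePlanRequest
      .newBuilder()
      .setSessionId(sessionId)
      .build()

    println(request.getSessionId)
  }
}
```

Because only the field name changes and the field number stays 1, the wire format is unaffected, which is consistent with the "no user-facing change" answer below.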

### Why are the changes needed?

Code readability

### Does this PR introduce _any_ user-facing change?

NO

### How was this patch tested?

Existing UT

Closes apache#40309 from amaliujia/update_client_id.

Authored-by: Rui Wang <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
amaliujia authored and hvanhovell committed Mar 7, 2023
1 parent d4818df commit dfdc4a1
Showing 13 changed files with 223 additions and 204 deletions.
@@ -63,7 +63,7 @@ private[sql] class SparkConnectClient(
.newBuilder()
.setPlan(plan)
.setUserContext(userContext)
.setClientId(sessionId)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
stub.executePlan(request)
@@ -78,7 +78,7 @@ private[sql] class SparkConnectClient(
val request = proto.ConfigRequest
.newBuilder()
.setOperation(operation)
.setClientId(sessionId)
.setSessionId(sessionId)
.setClientType(userAgent)
.setUserContext(userContext)
.build()
@@ -157,7 +157,7 @@ private[sql] class SparkConnectClient(
private def analyze(builder: proto.AnalyzePlanRequest.Builder): proto.AnalyzePlanResponse = {
val request = builder
.setUserContext(userContext)
.setClientId(sessionId)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
analyze(request)
@@ -612,8 +612,8 @@ class ClientE2ETestSuite extends RemoteSparkSession {
}

test("SparkSession newSession") {
val oldId = spark.sql("SELECT 1").analyze.getClientId
val newId = spark.newSession().sql("SELECT 1").analyze.getClientId
val oldId = spark.sql("SELECT 1").analyze.getSessionId
val newId = spark.newSession().sql("SELECT 1").analyze.getSessionId
assert(oldId != newId)
}

@@ -75,11 +75,11 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
client = clientBuilder(server.getPort)
val request = AnalyzePlanRequest
.newBuilder()
.setClientId("abc123")
.setSessionId("abc123")
.build()

val response = client.analyze(request)
assert(response.getClientId === "abc123")
assert(response.getSessionId === "abc123")
}

test("Test connection") {
@@ -99,7 +99,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
.connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true")
.build()

val request = AnalyzePlanRequest.newBuilder().setClientId("abc123").build()
val request = AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()

// Failed the ssl handshake as the dummy server does not have any server credentials installed.
assertThrows[StatusRuntimeException] {
@@ -201,11 +201,11 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
request: ExecutePlanRequest,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
// Reply with a dummy response using the same session ID
val requestClientId = request.getClientId
val requestSessionId = request.getSessionId
inputPlan = request.getPlan
val response = ExecutePlanResponse
.newBuilder()
.setClientId(requestClientId)
.setSessionId(requestSessionId)
.build()
responseObserver.onNext(response)
responseObserver.onCompleted()
@@ -215,7 +215,7 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
request: AnalyzePlanRequest,
responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {
// Reply with a dummy response using the same session ID
val requestClientId = request.getClientId
val requestSessionId = request.getSessionId
request.getAnalyzeCase match {
case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
inputPlan = request.getSchema.getPlan
@@ -233,7 +233,7 @@ }
}
val response = AnalyzePlanResponse
.newBuilder()
.setClientId(requestClientId)
.setSessionId(requestSessionId)
.build()
responseObserver.onNext(response)
responseObserver.onCompleted()
41 changes: 25 additions & 16 deletions connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -58,9 +58,10 @@ message UserContext {
message AnalyzePlanRequest {
// (Required)
//
// The client_id is set by the client to be able to collate streaming responses from
// different queries.
string client_id = 1;
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
string session_id = 1;

// (Required) User context
UserContext user_context = 2;
@@ -161,7 +162,7 @@ message AnalyzePlanRequest {
// Response to performing analysis of the query. Contains relevant metadata to be able to
// reason about the performance.
message AnalyzePlanResponse {
string client_id = 1;
string session_id = 1;

oneof result {
Schema schema = 2;
@@ -217,11 +218,15 @@ message AnalyzePlanResponse {
message ExecutePlanRequest {
// (Required)
//
// The client_id is set by the client to be able to collate streaming responses from
// different queries.
string client_id = 1;
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
string session_id = 1;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote Spark session on the
// server side.
UserContext user_context = 2;

// (Required) The logical plan to be executed / analyzed.
@@ -234,9 +239,9 @@ message ExecutePlanRequest {
}

// The response of a query, can be one or more for each request. Responses belonging to the
// same input query, carry the same `client_id`.
// same input query, carry the same `session_id`.
message ExecutePlanResponse {
string client_id = 1;
string session_id = 1;

// Union type for the different response messages.
oneof response_type {
@@ -304,9 +309,10 @@ message KeyValue {
message ConfigRequest {
// (Required)
//
// The client_id is set by the client to be able to collate streaming responses from
// different queries.
string client_id = 1;
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
string session_id = 1;

// (Required) User context
UserContext user_context = 2;
@@ -369,7 +375,7 @@ message ConfigRequest {

// Response to the config request.
message ConfigResponse {
string client_id = 1;
string session_id = 1;

// (Optional) The result key-value pairs.
//
@@ -386,9 +392,12 @@ message ConfigResponse {
// Request to transfer client-local artifacts.
message AddArtifactsRequest {

// The client_id is set by the client to be able to collate streaming responses from
// different queries.
string client_id = 1;
// (Required)
//
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
string session_id = 1;

// User context
UserContext user_context = 2;
@@ -1459,7 +1459,7 @@ class SparkConnectPlanner(val session: SparkSession) {

def process(
command: proto.Command,
clientId: String,
sessionId: String,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
command.getCommandTypeCase match {
case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
@@ -1473,14 +1473,14 @@ class SparkConnectPlanner(val session: SparkSession) {
case proto.Command.CommandTypeCase.EXTENSION =>
handleCommandPlugin(command.getExtension)
case proto.Command.CommandTypeCase.SQL_COMMAND =>
handleSqlCommand(command.getSqlCommand, clientId, responseObserver)
handleSqlCommand(command.getSqlCommand, sessionId, responseObserver)
case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
}

def handleSqlCommand(
getSqlCommand: SqlCommand,
clientId: String,
sessionId: String,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
// Eagerly execute commands of the provided SQL string.
val df = session.sql(getSqlCommand.getSql, getSqlCommand.getArgsMap)
@@ -1537,12 +1537,12 @@ class SparkConnectPlanner(val session: SparkSession) {
responseObserver.onNext(
ExecutePlanResponse
.newBuilder()
.setClientId(clientId)
.setSessionId(sessionId)
.setSqlCommandResult(result)
.build())

// Send Metrics
SparkConnectStreamHandler.sendMetricsToResponse(clientId, df)
SparkConnectStreamHandler.sendMetricsToResponse(sessionId, df)
}

private def handleRegisterUserDefinedFunction(
@@ -35,7 +35,7 @@ private[connect] class SparkConnectAnalyzeHandler(
def handle(request: proto.AnalyzePlanRequest): Unit = {
val session =
SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId)
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.session
session.withActive {
val response = process(request, session)
@@ -155,7 +155,7 @@ private[connect] class SparkConnectAnalyzeHandler(
case other => throw InvalidPlanInput(s"Unknown Analyze Method $other!")
}

builder.setClientId(request.getClientId)
builder.setSessionId(request.getSessionId)
builder.build()
}
}
@@ -32,7 +32,7 @@ class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigRes
def handle(request: proto.ConfigRequest): Unit = {
val session =
SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId)
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.session

val builder = request.getOperation.getOpTypeCase match {
@@ -53,7 +53,7 @@ class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigRes
case _ => throw new UnsupportedOperationException(s"${request.getOperation} not supported.")
}

builder.setClientId(request.getClientId)
builder.setSessionId(request.getSessionId)
responseObserver.onNext(builder.build())
responseObserver.onCompleted()
}
@@ -44,7 +44,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
def handle(v: ExecutePlanRequest): Unit = {
val session =
SparkConnectService
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getClientId)
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
.session
session.withActive {
v.getPlan.getOpTypeCase match {
@@ -60,20 +60,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(session)
val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
processAsArrowBatches(request.getClientId, dataframe, responseObserver)
processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
responseObserver.onNext(
SparkConnectStreamHandler.sendMetricsToResponse(request.getClientId, dataframe))
SparkConnectStreamHandler.sendMetricsToResponse(request.getSessionId, dataframe))
if (dataframe.queryExecution.observedMetrics.nonEmpty) {
responseObserver.onNext(
SparkConnectStreamHandler.sendObservedMetricsToResponse(request.getClientId, dataframe))
SparkConnectStreamHandler.sendObservedMetricsToResponse(request.getSessionId, dataframe))
}
responseObserver.onCompleted()
}

private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
val command = request.getPlan.getCommand
val planner = new SparkConnectPlanner(session)
planner.process(command, request.getClientId, responseObserver)
planner.process(command, request.getSessionId, responseObserver)
responseObserver.onCompleted()
}
}
@@ -96,7 +96,7 @@ object SparkConnectStreamHandler {
}

def processAsArrowBatches(
clientId: String,
sessionId: String,
dataframe: DataFrame,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
val spark = dataframe.sparkSession
@@ -173,7 +173,7 @@ object SparkConnectStreamHandler {
}

partition.foreach { case (bytes, count) =>
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(count)
@@ -191,7 +191,7 @@ object SparkConnectStreamHandler {
// Make sure at least 1 batch will be sent.
if (numSent == 0) {
val bytes = ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(0L)
Expand All @@ -203,17 +203,17 @@ object SparkConnectStreamHandler {
}
}

def sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = {
def sendMetricsToResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = {
// Send a last batch with the metrics
ExecutePlanResponse
.newBuilder()
.setClientId(clientId)
.setSessionId(sessionId)
.setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan))
.build()
}

def sendObservedMetricsToResponse(
clientId: String,
sessionId: String,
dataframe: DataFrame): ExecutePlanResponse = {
val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
val cols = (0 until row.length).map(i => toConnectProtoValue(row(i)))
Expand All @@ -226,7 +226,7 @@ object SparkConnectStreamHandler {
// Prepare a response with the observed metrics.
ExecutePlanResponse
.newBuilder()
.setClientId(clientId)
.setSessionId(sessionId)
.addAllObservedMetrics(observedMetrics.asJava)
.build()
}
@@ -221,7 +221,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.newBuilder()
.setPlan(plan)
.setUserContext(context)
.setClientId("session")
.setSessionId("session")
.build()

// The observer is executed inside this thread. So