Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44860][SQL] Add SESSION_USER function #42549

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -3346,6 +3346,14 @@ object functions {
*/
def user(): Column = Column.fn("user")

/**
* Returns the user name of current execution context.
*
* @group misc_funcs
* @since 4.0.0
*/
def session_user(): Column = Column.fn("session_user")

/**
* Returns an universally unique identifier (UUID) string. The value is returned as a canonical
* UUID 36-character string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,11 @@ class PlanGenerationTestSuite
fn.user()
}

functionTest("session_user") {
fn.session_user()
}


functionTest("md5") {
fn.md5(fn.col("g").cast("binary"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ datetimeUnit
;

primaryExpression
: name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER | USER) #currentLike
: name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER | USER | SESSION_USER) #currentLike
| name=(TIMESTAMPADD | DATEADD | DATE_ADD) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA unitsAmount=valueExpression COMMA timestamp=valueExpression RIGHT_PAREN #timestampadd
| name=(TIMESTAMPDIFF | DATEDIFF | DATE_DIFF | TIMEDIFF) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA startTimestamp=valueExpression COMMA endTimestamp=valueExpression RIGHT_PAREN #timestampdiff
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,13 @@ trait ColumnResolutionHelper extends Logging {
}
}

// support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id
// support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, SESSION_USER and grouping__id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, SESSION_USER and grouping__id
// support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, USER, SESSION_USER and grouping__id

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
(CurrentUser().prettyName, () => CurrentUser(), toPrettySQL),
("user", () => CurrentUser(), toPrettySQL),
("session_user", () => CurrentUser(), toPrettySQL),
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ object FunctionRegistry {
expression[CurrentCatalog]("current_catalog"),
expression[CurrentUser]("current_user"),
expression[CurrentUser]("user", setAlias = true),
expression[CurrentUser]("session_user", setAlias = true),
expression[CallMethodViaReflection]("reflect"),
expression[CallMethodViaReflection]("java_method", true),
expression[SparkVersion]("version"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2099,7 +2099,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
CurrentDate()
case SqlBaseParser.CURRENT_TIMESTAMP =>
CurrentTimestamp()
case SqlBaseParser.CURRENT_USER | SqlBaseParser.USER =>
case SqlBaseParser.CURRENT_USER | SqlBaseParser.USER | SqlBaseParser.SESSION_USER =>
CurrentUser()
}
} else {
Expand Down
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3370,6 +3370,14 @@ object functions {
*/
def user(): Column = withExpr { CurrentUser() }

/**
* Returns the user name of current execution context.
*
* @group misc_funcs
* @since 4.0.0
*/
def session_user(): Column = withExpr { CurrentUser() }

/**
* Returns an universally unique identifier (UUID) string. The value is returned as a canonical
* UUID 36-character string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5878,11 +5878,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.selectExpr("CURRENT_SCHEMA()"), df.select(current_schema()))
}

test("function current_user, user") {
test("function current_user, user, session_user") {
val df = Seq((1, 2), (3, 1)).toDF("a", "b")

checkAnswer(df.selectExpr("CURRENT_USER()"), df.select(current_user()))
checkAnswer(df.selectExpr("USER()"), df.select(user()))
checkAnswer(df.selectExpr("SESSION_USER()"), df.select(session_user()))
}

test("named_struct function") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,16 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.selectExpr("version()"), df.select(version()))
}

test("SPARK-21957: get current_user in normal spark apps") {
test("SPARK-21957, SPARK-44860: get current_user, session_user in normal spark apps") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also update SPARK-21957: get current_user through thrift server?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

val user = spark.sparkContext.sparkUser
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
val df = sql("select current_user(), current_user, user, user()")
checkAnswer(df, Row(user, user, user, user))
val df =
sql("select current_user(), current_user, user, user(), session_user(), session_user")
checkAnswer(df, Row(user, user, user, user, user, user))
}
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true",
SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") {
Seq("user", "current_user").foreach { func =>
Seq("user", "current_user", "session_user").foreach { func =>
checkAnswer(sql(s"select $func"), Row(user))
checkError(
exception = intercept[ParseException](sql(s"select $func()")),
Expand Down