[SPARK-49552][CONNECT][FOLLOW-UP] Make 'randstr' and 'uniform' deterministic in Scala Client

### What changes were proposed in this pull request?
Make 'randstr' and 'uniform' deterministic in Scala Client

### Why are the changes needed?
We need to explicitly set the seed in Connect clients to avoid making the output DataFrame non-deterministic (see 14ba4fc).

When reviewing #48143, I asked the author to set the seed in the Python client.
At that time, however, I was not aware that the Spark Connect Scala Client reuses the same `functions.scala` under `org.apache.spark.sql` (there used to be two separate files).

So the two functions can produce non-deterministic results, e.g. two `show()` calls on the same DataFrame returning different rows:
```
scala> val df = spark.range(10).select(randstr(lit(10)).as("r"))
Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties
df: org.apache.spark.sql.package.DataFrame = [r: string]

scala> df.show()
+----------+
|         r|
+----------+
|5bhIk72PJa|
|tuhC50Di38|
|PxwfWzdT3X|
|sWkmSyWboh|
|uZMS4htmM0|
|YMxMwY5wdQ|
|JDaWSiBwDD|
|C7KQ20WE7t|
|IwSSqWOObg|
|jDF2Ndfy8q|
+----------+

scala> df.show()
+----------+
|         r|
+----------+
|fpnnoLJbOA|
|qerIKpYPif|
|PvliXYIALD|
|xK3fosAvOp|
|WK12kfkPXq|
|2UcdyAEbNm|
|HEkl4rMtV1|
|PCaH4YJuYo|
|JuuXEHSp5i|
|jSLjl8ug8S|
+----------+
```
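The mechanics can be sketched outside Spark (a toy model; `RandStrExpr` is hypothetical and not Spark API): once the seed is captured when the expression is constructed, re-executing the same expression yields identical output, which is exactly what pinning the seed client-side buys us.

```scala
import scala.util.Random

// Toy model of a random-string expression. An unseeded expression would be
// re-randomized by the server on every execution; capturing a seed once at
// construction time makes repeated executions agree.
final case class RandStrExpr(length: Int, seed: Long) {
  // Deterministic per (seed, row): the same inputs always yield the same string.
  def evaluate(row: Long): String =
    new Random(seed ^ row).alphanumeric.take(length).mkString
}

// The commit's pattern: draw the seed once when the expression is built.
val expr = RandStrExpr(10, seed = Random.nextLong())

val first  = (0L until 10L).map(expr.evaluate)  // first show()
val second = (0L until 10L).map(expr.evaluate)  // second show()
assert(first == second)  // repeated executions now agree
```

Each row still gets its own pseudo-random value (the seed is mixed with the row index), but the sequence as a whole is fixed once the expression exists.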

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
After this fix, repeated `show()` calls return the same values:
```
scala> val df = spark.range(10).select(randstr(lit(10)).as("r"))
df: org.apache.spark.sql.package.DataFrame = [r: string]

scala> df.show()
+----------+
|         r|
+----------+
|Gri9B9X8zI|
|gfhpGD8PcV|
|FDaXofTzlN|
|p7ciOScWpu|
|QZiEbF5q7c|
|9IhRoXmTUM|
|TeSEG1EKSN|
|B7nLw5iedL|
|uFZo1WPLPT|
|46E2LVCxxl|
+----------+

scala> df.show()
+----------+
|         r|
+----------+
|Gri9B9X8zI|
|gfhpGD8PcV|
|FDaXofTzlN|
|p7ciOScWpu|
|QZiEbF5q7c|
|9IhRoXmTUM|
|TeSEG1EKSN|
|B7nLw5iedL|
|uFZo1WPLPT|
|46E2LVCxxl|
+----------+
```

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48558 from zhengruifeng/sql_rand_str_seed.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Oct 22, 2024
Parent: 91ae102 — Commit: 70c9b1f
Showing 2 changed files with 8 additions and 6 deletions.
`sql/api/src/main/scala/org/apache/spark/sql/functions.scala` (6 changes: 4 additions & 2 deletions)

```
@@ -1901,7 +1901,8 @@ object functions {
    * @group string_funcs
    * @since 4.0.0
    */
-  def randstr(length: Column): Column = Column.fn("randstr", length)
+  def randstr(length: Column): Column =
+    randstr(length, lit(SparkClassUtils.random.nextLong))

   /**
    * Returns a string of the specified length whose characters are chosen uniformly at random from
@@ -3767,7 +3768,8 @@ object functions {
    * @group math_funcs
    * @since 4.0.0
    */
-  def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max)
+  def uniform(min: Column, max: Column): Column =
+    uniform(min, max, lit(SparkClassUtils.random.nextLong))

   /**
    * Returns a random value with independent and identically distributed (i.i.d.) values with the
```
```
@@ -507,12 +507,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
     // Here we exercise some error cases.
     val df = Seq((0)).toDF("a")
-    var expr = uniform(lit(10), lit("a"))
+    var expr = uniform(lit(10), lit("a"), lit(1))
     checkError(
       intercept[AnalysisException](df.select(expr)),
       condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
       parameters = Map(
-        "sqlExpr" -> "\"uniform(10, a)\"",
+        "sqlExpr" -> "\"uniform(10, a, 1)\"",
         "paramIndex" -> "second",
         "inputSql" -> "\"a\"",
         "inputType" -> "\"STRING\"",
@@ -525,15 +525,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         callSitePattern = "",
         startIndex = 0,
         stopIndex = 0))
-    expr = uniform(col("a"), lit(10))
+    expr = uniform(col("a"), lit(10), lit(1))
     checkError(
       intercept[AnalysisException](df.select(expr)),
       condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
       parameters = Map(
         "inputName" -> "`min`",
         "inputType" -> "integer or floating-point",
         "inputExpr" -> "\"a\"",
-        "sqlExpr" -> "\"uniform(a, 10)\""),
+        "sqlExpr" -> "\"uniform(a, 10, 1)\""),
       context = ExpectedContext(
         contextType = QueryContextType.DataFrame,
         fragment = "uniform",
```
