diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 7fed175cbc8ea..2a39bc6bfddda 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -65,7 +65,6 @@ from pyspark.sql.types import ( _from_numpy_type, DataType, - LongType, StructType, ArrayType, StringType, @@ -2206,12 +2205,9 @@ def schema_of_xml(xml: Union[str, Column], options: Optional[Mapping[str, str]] schema_of_xml.__doc__ = pysparkfuncs.schema_of_xml.__doc__ -def shuffle(col: "ColumnOrName") -> Column: - return _invoke_function( - "shuffle", - _to_col(col), - LiteralExpression(random.randint(0, sys.maxsize), LongType()), - ) +def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column: + _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed) + return _invoke_function("shuffle", _to_col(col), _seed) shuffle.__doc__ = pysparkfuncs.shuffle.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 5f8d1c21a24f1..2d5dbb5946050 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -17723,7 +17723,7 @@ def array_sort( @_try_remote_functions -def shuffle(col: "ColumnOrName") -> Column: +def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column: """ Array function: Generates a random permutation of the given array. @@ -17736,6 +17736,10 @@ def shuffle(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str The name of the column or expression to be shuffled. + seed : :class:`~pyspark.sql.Column` or int, optional + Seed value for the random generator. + + .. versionadded:: 4.0.0 Returns ------- @@ -17752,48 +17756,51 @@ def shuffle(col: "ColumnOrName") -> Column: Example 1: Shuffling a simple array >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 20, 3, 5],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +-------------+ - |shuffle(data)| - +-------------+ - |[1, 3, 20, 5]| - +-------------+ + >>> df = spark.sql("SELECT ARRAY(1, 20, 3, 5) AS data") + >>> df.select("*", sf.shuffle(df.data, sf.lit(123))).show() + +-------------+-------------+ + | data|shuffle(data)| + +-------------+-------------+ + |[1, 20, 3, 5]|[5, 1, 20, 3]| + +-------------+-------------+ Example 2: Shuffling an array with null values >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 20, None, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +----------------+ - | shuffle(data)| - +----------------+ - |[20, 3, NULL, 1]| - +----------------+ + >>> df = spark.sql("SELECT ARRAY(1, 20, NULL, 5) AS data") + >>> df.select("*", sf.shuffle(sf.col("data"), 234)).show() + +----------------+----------------+ + | data| shuffle(data)| + +----------------+----------------+ + |[1, 20, NULL, 5]|[NULL, 5, 20, 1]| + +----------------+----------------+ Example 3: Shuffling an array with duplicate values >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 2, 2, 3, 3, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +------------------+ - | shuffle(data)| - +------------------+ - |[3, 2, 1, 3, 2, 3]| - +------------------+ + >>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data") + >>> df.select("*", sf.shuffle("data", 345)).show() + +------------------+------------------+ + | data| shuffle(data)| + +------------------+------------------+ + |[1, 2, 2, 3, 3, 3]|[2, 3, 3, 1, 2, 3]| + +------------------+------------------+ - Example 4: Shuffling an array with different types of elements + Example 4: Shuffling an array with random seed >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([(['a', 'b', 'c', 1, 2, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +------------------+ - | shuffle(data)| - +------------------+ - |[1, c, 2, a, b, 3]| - +------------------+ + >>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data") + >>> df.select("*", sf.shuffle("data")).show() # doctest: +SKIP + +------------------+------------------+ + | data| shuffle(data)| + +------------------+------------------+ + |[1, 2, 2, 3, 3, 3]|[3, 3, 2, 3, 2, 1]| + +------------------+------------------+ """ - return _invoke_function_over_columns("shuffle", col) + if seed is not None: + return _invoke_function_over_columns("shuffle", col, lit(seed)) + else: + return _invoke_function_over_columns("shuffle", col) @_try_remote_functions diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 0662b8f2b271f..d9bceabe88f8f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -7252,7 +7252,18 @@ object functions { * @group array_funcs * @since 2.4.0 */ - def shuffle(e: Column): Column = Column.fn("shuffle", e, lit(SparkClassUtils.random.nextLong)) + def shuffle(e: Column): Column = shuffle(e, lit(SparkClassUtils.random.nextLong)) + + /** + * Returns a random permutation of the given array. + * + * @note + * The function is non-deterministic. + * + * @group array_funcs + * @since 4.0.0 + */ + def shuffle(e: Column, seed: Column): Column = Column.fn("shuffle", e, seed) /** * Returns a reversed string or an array with reverse order of elements.