[SPARK-50851][ML][CONNECT][PYTHON] Express ML params with `proto.Expression.Literal`

### What changes were proposed in this pull request?
Express ML params with `proto.Expression.Literal`:
1. Introduce `Literal.SpecializedArray` for large primitive literal arrays (e.g. initial model coefficients, which can be large); a small packing sketch follows the message definition below.
```
    message SpecializedArray {
      oneof value_type {
        Bools bools = 1;
        Ints ints = 2;
        Longs longs = 3;
        Floats floats = 4;
        Doubles doubles = 5;
        Strings strings = 6;
      }

      message Bools {
        repeated bool values = 1;
      }

      message Ints {
        repeated int32 values = 1;
      }

      message Longs {
        repeated int64 values = 1;
      }

      message Floats {
        repeated float values = 1;
      }

      message Doubles {
        repeated double values = 1;
      }

      message Strings {
        repeated string values = 1;
      }
    }
```
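
As a rough illustration (not part of the patch), assuming the regenerated Python bindings are importable as `pb2`, a large coefficient array becomes a single packed repeated-`double` field instead of one nested `Literal` message per element; `pack_doubles` below is a hypothetical helper mirroring the `build_float_list` added in this change:
```
import pyspark.sql.connect.proto as pb2


def pack_doubles(values):
    # Pack a list of Python floats into a single
    # Expression.Literal.SpecializedArray.Doubles message.
    lit = pb2.Expression.Literal()
    lit.specialized_array.doubles.values.extend(values)
    return lit


# e.g. initial model coefficients with a million entries
coefficients = [0.0] * 1_000_000
lit = pack_doubles(coefficients)
```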
2. Replace `proto.Param` with `proto.Expression.Literal`, to be consistent with the SQL side.
For `Param[Vector]` and `Param[Matrix]`, use `proto.Expression.Literal.Struct` with the underlying schemas of `VectorUDT` and `MatrixUDT`.

E.g. for a `Param[Vector]` with value `Vectors.sparse(4, [(1, 1.0), (3, 5.5)])`, the message looks like:
```
literal {
  struct {
    struct_type {
      struct {
        ... <- schema of VectorUDT
      }
    }
    elements {
      byte: 0
    }
    elements {
      integer: 4
    }
    elements {
      specialized_array {
        ints {
          values: 1
          values: 3
        }
      }
    }
    elements {
      specialized_array {
        doubles {
          values: 1
          values: 5.5
        }
      }
    }
  }
}
```
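
For reference, the four `elements` above line up one-to-one with the fields of the sparse vector; a quick check in plain PySpark (no Connect involved) of the values being serialized:
```
from pyspark.ml.linalg import Vectors

v = Vectors.sparse(4, [(1, 1.0), (3, 5.5)])
v.size              # 4          -> elements[1].integer
v.indices.tolist()  # [1, 3]     -> elements[2].specialized_array.ints
v.values.tolist()   # [1.0, 5.5] -> elements[3].specialized_array.doubles
# elements[0].byte == 0 marks the sparse encoding (1 would be dense).
```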

### Why are the changes needed?
1. To optimize large literal arrays, for both ML and SQL (it can be applied on the SQL side later).
2. To be consistent with the SQL side, e.g. parameterized SQL (a usage sketch follows this list):
```
  // (Optional) A map of parameter names to expressions.
  // It cannot coexist with `pos_arguments`.
  map<string, Expression.Literal> named_arguments = 4;

  // (Optional) A sequence of expressions for positional parameters in the SQL query text.
  // It cannot coexist with `named_arguments`.
  repeated Expression pos_arguments = 5;
```
3. To minimize the protobuf changes.
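
A usage sketch of that existing path, assuming an active Spark Connect `spark` session (`min_id` is just an illustrative parameter name): named parameters passed to `spark.sql` travel as the `named_arguments` literal map shown above.
```
# ":min_id" is sent as an entry in `named_arguments`,
# i.e. a map<string, Expression.Literal>.
df = spark.sql("SELECT * FROM range(10) WHERE id > :min_id", args={"min_id": 5})
df.show()
```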

### Does this PR introduce _any_ user-facing change?
No, refactor-only.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49529 from zhengruifeng/ml_proto_expr.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jan 17, 2025
1 parent 10dd350 commit 2721a50
Showing 23 changed files with 906 additions and 871 deletions.
39 changes: 22 additions & 17 deletions mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
@@ -27,23 +27,7 @@ import org.apache.spark.sql.types._
*/
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

override def sqlType: StructType = {
// type: 0 = sparse, 1 = dense
// the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
// set as not nullable, except values since in the future, support for binary matrices might
// be added for which values are not needed.
// the sparse matrix needs colPtrs and rowIndices, which are set as
// null, while building the dense matrix.
StructType(Array(
StructField("type", ByteType, nullable = false),
StructField("numRows", IntegerType, nullable = false),
StructField("numCols", IntegerType, nullable = false),
StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
StructField("isTransposed", BooleanType, nullable = false)
))
}
override def sqlType: StructType = MatrixUDT.sqlType

override def serialize(obj: Matrix): InternalRow = {
val row = new GenericInternalRow(7)
@@ -108,3 +92,24 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

private[spark] override def asNullable: MatrixUDT = this
}

private[spark] object MatrixUDT {

val sqlType: StructType = {
// type: 0 = sparse, 1 = dense
// the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
// set as not nullable, except values since in the future, support for binary matrices might
// be added for which values are not needed.
// the sparse matrix needs colPtrs and rowIndices, which are set as
// null, while building the dense matrix.
StructType(Array(
StructField("type", ByteType, nullable = false),
StructField("numRows", IntegerType, nullable = false),
StructField("numCols", IntegerType, nullable = false),
StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
StructField("isTransposed", BooleanType, nullable = false)
))
}
}
mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
*/
private[spark] class VectorUDT extends UserDefinedType[Vector] {

override final def sqlType: StructType = _sqlType
override final def sqlType: StructType = VectorUDT.sqlType

override def serialize(obj: Vector): InternalRow = {
obj match {
@@ -86,8 +86,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def typeName: String = "vector"

private[spark] override def asNullable: VectorUDT = this
}

private[spark] object VectorUDT {

private[this] val _sqlType = {
val sqlType = {
// type: 0 = sparse, 1 = dense
// We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
// vectors. The "values" field is nullable because we might want to add binary vectors later,
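
From the Python side, these are the same UDT schemas that the new struct literals mirror; a quick sanity check (a sketch assuming a local PySpark install):
```
from pyspark.ml.linalg import MatrixUDT, VectorUDT

# Param[Vector] literals follow these four fields:
#   type (byte), size (int), indices (array<int>), values (array<double>)
print(VectorUDT.sqlType())

# Param[Matrix] literals follow these seven fields:
#   type, numRows, numCols, colPtrs, rowIndices, values, isTransposed
print(MatrixUDT.sqlType())
```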
200 changes: 133 additions & 67 deletions python/pyspark/ml/connect/serialize.py
@@ -18,54 +18,107 @@

import pyspark.sql.connect.proto as pb2
from pyspark.ml.linalg import (
Vectors,
Matrices,
VectorUDT,
MatrixUDT,
DenseVector,
SparseVector,
DenseMatrix,
SparseMatrix,
)
from pyspark.sql.connect.expressions import LiteralExpression

if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.ml.param import Params


def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Param:
if isinstance(value, DenseVector):
return pb2.Param(vector=pb2.Vector(dense=pb2.Vector.Dense(value=value.values.tolist())))
elif isinstance(value, SparseVector):
return pb2.Param(
vector=pb2.Vector(
sparse=pb2.Vector.Sparse(
size=value.size, index=value.indices.tolist(), value=value.values.tolist()
)
)
)
elif isinstance(value, DenseMatrix):
return pb2.Param(
matrix=pb2.Matrix(
dense=pb2.Matrix.Dense(
num_rows=value.numRows, num_cols=value.numCols, value=value.values.tolist()
)
)
)
def literal_null() -> pb2.Expression.Literal:
dt = pb2.DataType()
dt.null.CopyFrom(pb2.DataType.NULL())
return pb2.Expression.Literal(null=dt)


def build_int_list(value: List[int]) -> pb2.Expression.Literal:
p = pb2.Expression.Literal()
p.specialized_array.ints.values.extend(value)
return p


def build_float_list(value: List[float]) -> pb2.Expression.Literal:
p = pb2.Expression.Literal()
p.specialized_array.doubles.values.extend(value)
return p


def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.connect.expressions import LiteralExpression

if isinstance(value, SparseVector):
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
# type = 0
p.struct.elements.append(pb2.Expression.Literal(byte=0))
# size
p.struct.elements.append(pb2.Expression.Literal(integer=value.size))
# indices
p.struct.elements.append(build_int_list(value.indices.tolist()))
# values
p.struct.elements.append(build_float_list(value.values.tolist()))
return p

elif isinstance(value, DenseVector):
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
# type = 1
p.struct.elements.append(pb2.Expression.Literal(byte=1))
# size = null
p.struct.elements.append(literal_null())
# indices = null
p.struct.elements.append(literal_null())
# values
p.struct.elements.append(build_float_list(value.values.tolist()))
return p

elif isinstance(value, SparseMatrix):
return pb2.Param(
matrix=pb2.Matrix(
sparse=pb2.Matrix.Sparse(
num_rows=value.numRows,
num_cols=value.numCols,
colptr=value.colPtrs.tolist(),
row_index=value.rowIndices.tolist(),
value=value.values.tolist(),
)
)
)
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
# type = 0
p.struct.elements.append(pb2.Expression.Literal(byte=0))
# numRows
p.struct.elements.append(pb2.Expression.Literal(integer=value.numRows))
# numCols
p.struct.elements.append(pb2.Expression.Literal(integer=value.numCols))
# colPtrs
p.struct.elements.append(build_int_list(value.colPtrs.tolist()))
# rowIndices
p.struct.elements.append(build_int_list(value.rowIndices.tolist()))
# values
p.struct.elements.append(build_float_list(value.values.tolist()))
# isTransposed
p.struct.elements.append(pb2.Expression.Literal(boolean=value.isTransposed))
return p

elif isinstance(value, DenseMatrix):
p = pb2.Expression.Literal()
p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
# type = 1
p.struct.elements.append(pb2.Expression.Literal(byte=1))
# numRows
p.struct.elements.append(pb2.Expression.Literal(integer=value.numRows))
# numCols
p.struct.elements.append(pb2.Expression.Literal(integer=value.numCols))
# colPtrs = null
p.struct.elements.append(literal_null())
# rowIndices = null
p.struct.elements.append(literal_null())
# values
p.struct.elements.append(build_float_list(value.values.tolist()))
# isTransposed
p.struct.elements.append(pb2.Expression.Literal(boolean=value.isTransposed))
return p

else:
literal = LiteralExpression._from_value(value).to_plan(client).literal
return pb2.Param(literal=literal)
return LiteralExpression._from_value(value).to_plan(client).literal


def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
@@ -80,38 +80,51 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
return result


def deserialize_param(param: pb2.Param) -> Any:
if param.HasField("literal"):
return LiteralExpression._to_value(param.literal)
if param.HasField("vector"):
vector = param.vector
if vector.HasField("dense"):
return Vectors.dense(vector.dense.value)
elif vector.HasField("sparse"):
return Vectors.sparse(vector.sparse.size, vector.sparse.index, vector.sparse.value)
else:
raise ValueError("Unsupported vector type")
if param.HasField("matrix"):
matrix = param.matrix
if matrix.HasField("dense"):
return DenseMatrix(
matrix.dense.num_rows,
matrix.dense.num_cols,
matrix.dense.value,
matrix.dense.is_transposed,
)
elif matrix.HasField("sparse"):
return Matrices.sparse(
matrix.sparse.num_rows,
matrix.sparse.num_cols,
matrix.sparse.colptr,
matrix.sparse.row_index,
matrix.sparse.value,
)
def deserialize_param(literal: pb2.Expression.Literal) -> Any:
from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
from pyspark.sql.connect.expressions import LiteralExpression

if literal.HasField("struct"):
s = literal.struct
schema = proto_schema_to_pyspark_data_type(s.struct_type)

if schema == VectorUDT.sqlType():
assert len(s.elements) == 4
tpe = s.elements[0].byte
if tpe == 0:
size = s.elements[1].integer
indices = s.elements[2].specialized_array.ints.values
values = s.elements[3].specialized_array.doubles.values
return SparseVector(size, indices, values)
elif tpe == 1:
values = s.elements[3].specialized_array.doubles.values
return DenseVector(values)
else:
raise ValueError(f"Unknown Vector type {tpe}")

elif schema == MatrixUDT.sqlType():
assert len(s.elements) == 7
tpe = s.elements[0].byte
if tpe == 0:
numRows = s.elements[1].integer
numCols = s.elements[2].integer
colPtrs = s.elements[3].specialized_array.ints.values
rowIndices = s.elements[4].specialized_array.ints.values
values = s.elements[5].specialized_array.doubles.values
isTransposed = s.elements[6].boolean
return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
elif tpe == 1:
numRows = s.elements[1].integer
numCols = s.elements[2].integer
values = s.elements[5].specialized_array.doubles.values
isTransposed = s.elements[6].boolean
return DenseMatrix(numRows, numCols, values, isTransposed)
else:
raise ValueError(f"Unknown Matrix type {tpe}")
else:
raise ValueError("Unsupported matrix type")

raise ValueError("Unsupported param type")
raise ValueError(f"Unsupported parameter struct {schema}")
else:
return LiteralExpression._to_value(literal)


def deserialize(ml_command_result_properties: Dict[str, Any]) -> Any:
@@ -126,7 +192,7 @@ def deserialize(ml_command_result_properties: Dict[str, Any]) -> Any:


def serialize_ml_params(instance: "Params", client: "SparkConnectClient") -> pb2.MlParams:
params: Mapping[str, pb2.Param] = {
params: Mapping[str, pb2.Expression.Literal] = {
k.name: serialize_param(v, client) for k, v in instance._paramMap.items()
}
return pb2.MlParams(params=params)
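
A round-trip sketch of the helpers introduced above (`deserialize_param` needs no live client; assumes the regenerated proto bindings from this change):
```
import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.ml.linalg import SparseVector, VectorUDT
from pyspark.ml.connect.serialize import build_int_list, build_float_list, deserialize_param

v = SparseVector(4, [1, 3], [1.0, 5.5])

# Build the struct literal by hand, the same way serialize_param does.
lit = pb2.Expression.Literal()
lit.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
lit.struct.elements.append(pb2.Expression.Literal(byte=0))          # sparse
lit.struct.elements.append(pb2.Expression.Literal(integer=v.size))  # size
lit.struct.elements.append(build_int_list(v.indices.tolist()))      # indices
lit.struct.elements.append(build_float_list(v.values.tolist()))     # values

assert deserialize_param(lit) == v
```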
14 changes: 13 additions & 1 deletion python/pyspark/sql/connect/proto/common_pb2.py
@@ -35,7 +35,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 \n\x0cuse_off_heap\x18\x03 \x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 \x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 \x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3\x01\n\x17\x45xecutorResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x03R\x06\x61mount\x12.\n\x10\x64iscovery_script\x18\x03 \x01(\tH\x00R\x0f\x64iscoveryScript\x88\x01\x01\x12\x1b\n\x06vendor\x18\x04 \x01(\tH\x01R\x06vendor\x88\x01\x01\x42\x13\n\x11_discovery_scriptB\t\n\x07_vendor"R\n\x13TaskResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x01R\x06\x61mount"\xa5\x03\n\x0fResourceProfile\x12\x64\n\x12\x65xecutor_resources\x18\x01 \x03(\x0b\x32\x35.spark.connect.ResourceProfile.ExecutorResourcesEntryR\x11\x65xecutorResources\x12X\n\x0etask_resources\x18\x02 \x03(\x0b\x32\x31.spark.connect.ResourceProfile.TaskResourcesEntryR\rtaskResources\x1al\n\x16\x45xecutorResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12<\n\x05value\x18\x02 \x01(\x0b\x32&.spark.connect.ExecutorResourceRequestR\x05value:\x02\x38\x01\x1a\x64\n\x12TaskResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.TaskResourceRequestR\x05value:\x02\x38\x01"X\n\x06Origin\x12\x42\n\rpython_origin\x18\x01 \x01(\x0b\x32\x1b.spark.connect.PythonOriginH\x00R\x0cpythonOriginB\n\n\x08\x66unction"G\n\x0cPythonOrigin\x12\x1a\n\x08\x66ragment\x18\x01 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x02 \x01(\tR\x08\x63\x61llSiteB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 \n\x0cuse_off_heap\x18\x03 \x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 \x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 \x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3\x01\n\x17\x45xecutorResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x03R\x06\x61mount\x12.\n\x10\x64iscovery_script\x18\x03 \x01(\tH\x00R\x0f\x64iscoveryScript\x88\x01\x01\x12\x1b\n\x06vendor\x18\x04 \x01(\tH\x01R\x06vendor\x88\x01\x01\x42\x13\n\x11_discovery_scriptB\t\n\x07_vendor"R\n\x13TaskResourceRequest\x12#\n\rresource_name\x18\x01 \x01(\tR\x0cresourceName\x12\x16\n\x06\x61mount\x18\x02 \x01(\x01R\x06\x61mount"\xa5\x03\n\x0fResourceProfile\x12\x64\n\x12\x65xecutor_resources\x18\x01 \x03(\x0b\x32\x35.spark.connect.ResourceProfile.ExecutorResourcesEntryR\x11\x65xecutorResources\x12X\n\x0etask_resources\x18\x02 \x03(\x0b\x32\x31.spark.connect.ResourceProfile.TaskResourcesEntryR\rtaskResources\x1al\n\x16\x45xecutorResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12<\n\x05value\x18\x02 \x01(\x0b\x32&.spark.connect.ExecutorResourceRequestR\x05value:\x02\x38\x01\x1a\x64\n\x12TaskResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.TaskResourceRequestR\x05value:\x02\x38\x01"X\n\x06Origin\x12\x42\n\rpython_origin\x18\x01 \x01(\x0b\x32\x1b.spark.connect.PythonOriginH\x00R\x0cpythonOriginB\n\n\x08\x66unction"G\n\x0cPythonOrigin\x12\x1a\n\x08\x66ragment\x18\x01 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x02 \x01(\tR\x08\x63\x61llSite"\x1f\n\x05\x42ools\x12\x16\n\x06values\x18\x01 \x03(\x08R\x06values"\x1e\n\x04Ints\x12\x16\n\x06values\x18\x01 \x03(\x05R\x06values"\x1f\n\x05Longs\x12\x16\n\x06values\x18\x01 \x03(\x03R\x06values" \n\x06\x46loats\x12\x16\n\x06values\x18\x01 \x03(\x02R\x06values"!\n\x07\x44oubles\x12\x16\n\x06values\x18\x01 \x03(\x01R\x06values"!\n\x07Strings\x12\x16\n\x06values\x18\x01 \x03(\tR\x06valuesB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)

_globals = globals()
@@ -70,4 +70,16 @@
_globals["_ORIGIN"]._serialized_end = 1091
_globals["_PYTHONORIGIN"]._serialized_start = 1093
_globals["_PYTHONORIGIN"]._serialized_end = 1164
_globals["_BOOLS"]._serialized_start = 1166
_globals["_BOOLS"]._serialized_end = 1197
_globals["_INTS"]._serialized_start = 1199
_globals["_INTS"]._serialized_end = 1229
_globals["_LONGS"]._serialized_start = 1231
_globals["_LONGS"]._serialized_end = 1262
_globals["_FLOATS"]._serialized_start = 1264
_globals["_FLOATS"]._serialized_end = 1296
_globals["_DOUBLES"]._serialized_start = 1298
_globals["_DOUBLES"]._serialized_end = 1331
_globals["_STRINGS"]._serialized_start = 1333
_globals["_STRINGS"]._serialized_end = 1366
# @@protoc_insertion_point(module_scope)