Skip to content

Commit

Permalink
[backport] [pyspark] rework transform to reuse same code (#9292) (#9558)
Browse files Browse the repository at this point in the history
Co-authored-by: Bobby Wang <[email protected]>
  • Loading branch information
trivialfis and wbo4958 authored Sep 7, 2023
1 parent 3fde936 commit 4d387cb
Showing 1 changed file with 123 additions and 129 deletions.
252 changes: 123 additions & 129 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
from xgboost.training import train as worker_train

from .._typing import ArrayLike
from .data import (
_read_csr_matrix_from_unwrapped_spark_vec,
alias,
Expand Down Expand Up @@ -1117,12 +1118,86 @@ def _get_feature_col(
)
return features_col, feature_col_names

def _get_pred_contrib_col_name(self) -> Optional[str]:
"""Return the pred_contrib_col col name"""
pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

return pred_contrib_col_name

def _out_schema(self) -> Tuple[bool, str]:
"""Return the bool to indicate if it's a single prediction, true is single prediction,
and the returned type of the user-defined function. The value must
be a DDL-formatted type string."""

if self._get_pred_contrib_col_name() is not None:
return False, f"{pred.prediction} double, {pred.pred_contrib} array<double>"

return True, "double"

def _get_predict_func(self) -> Callable:
"""Return the true prediction function which will be running on the executor side"""

predict_params = self._gen_predict_params_dict()
pred_contrib_col_name = self._get_pred_contrib_col_name()

def _predict(
model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
) -> Union[pd.DataFrame, pd.Series]:
data = {}
preds = model.predict(
X,
base_margin=base_margin,
validate_features=False,
**predict_params,
)
data[pred.prediction] = pd.Series(preds)

if pred_contrib_col_name is not None:
contribs = pred_contribs(model, X, base_margin)
data[pred.pred_contrib] = pd.Series(list(contribs))
return pd.DataFrame(data=data)

return data[pred.prediction]

return _predict

def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
"""Post process of transform"""
prediction_col_name = self.getOrDefault(self.predictionCol)
single_pred, _ = self._out_schema()

if single_pred:
if prediction_col_name:
dataset = dataset.withColumn(prediction_col_name, pred_col)
else:
pred_struct_col = "_prediction_struct"
dataset = dataset.withColumn(pred_struct_col, pred_col)

if prediction_col_name:
dataset = dataset.withColumn(
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
)

pred_contrib_col_name = self._get_pred_contrib_col_name()
if pred_contrib_col_name is not None:
dataset = dataset.withColumn(
pred_contrib_col_name,
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
)

dataset = dataset.drop(pred_struct_col)
return dataset

def _transform(self, dataset: DataFrame) -> DataFrame:
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model
predict_params = self._gen_predict_params_dict()

has_base_margin = False
if (
Expand All @@ -1137,18 +1212,9 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
predict_func = self._get_predict_func()

single_pred = True
schema = "double"
if pred_contrib_col_name:
single_pred = False
schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
_, schema = self._out_schema()

@pandas_udf(schema) # type: ignore
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
Expand All @@ -1168,48 +1234,14 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
else:
base_margin = None

data = {}
preds = model.predict(
X,
base_margin=base_margin,
validate_features=False,
**predict_params,
)
data[pred.prediction] = pd.Series(preds)

if pred_contrib_col_name:
contribs = pred_contribs(model, X, base_margin)
data[pred.pred_contrib] = pd.Series(list(contribs))
yield pd.DataFrame(data=data)
else:
yield data[pred.prediction]
yield predict_func(model, X, base_margin)

if has_base_margin:
pred_col = predict_udf(struct(*features_col, base_margin_col))
else:
pred_col = predict_udf(struct(*features_col))

prediction_col_name = self.getOrDefault(self.predictionCol)

if single_pred:
dataset = dataset.withColumn(prediction_col_name, pred_col)
else:
pred_struct_col = "_prediction_struct"
dataset = dataset.withColumn(pred_struct_col, pred_col)

dataset = dataset.withColumn(
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
)

if pred_contrib_col_name:
dataset = dataset.withColumn(
pred_contrib_col_name,
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
)

dataset = dataset.drop(pred_struct_col)

return dataset
return self._post_transform(dataset, pred_col)


class _ClassificationModel( # pylint: disable=abstract-method
Expand All @@ -1221,22 +1253,21 @@ class _ClassificationModel( # pylint: disable=abstract-method
.. Note:: This API is experimental.
"""

def _transform(self, dataset: DataFrame) -> DataFrame:
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model
predict_params = self._gen_predict_params_dict()
def _out_schema(self) -> Tuple[bool, str]:
schema = (
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
f" {pred.probability} array<double>"
)
if self._get_pred_contrib_col_name() is not None:
# We will force setting strict_shape to True when predicting contribs,
# So, it will also output 3-D shape result.
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"

has_base_margin = False
if (
self.isDefined(self.base_margin_col)
and self.getOrDefault(self.base_margin_col) != ""
):
has_base_margin = True
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
alias.margin
)
return False, schema

def _get_predict_func(self) -> Callable:
predict_params = self._gen_predict_params_dict()
pred_contrib_col_name = self._get_pred_contrib_col_name()

def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if margins.ndim == 1:
Expand All @@ -1251,76 +1282,38 @@ def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
class_probs = softmax(raw_preds, axis=1)
return raw_preds, class_probs

features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

schema = (
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
f" {pred.probability} array<double>"
)
if pred_contrib_col_name:
# We will force setting strict_shape to True when predicting contribs,
# So, it will also output 3-D shape result.
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"

@pandas_udf(schema) # type: ignore
def predict_udf(
iterator: Iterator[Tuple[pd.Series, ...]]
) -> Iterator[pd.DataFrame]:
assert xgb_sklearn_model is not None
model = xgb_sklearn_model
for data in iterator:
if enable_sparse_data_optim:
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
else:
if feature_col_names is not None:
X = data[feature_col_names] # type: ignore
else:
X = stack_series(data[alias.data])

if has_base_margin:
base_margin = stack_series(data[alias.margin])
else:
base_margin = None

margins = model.predict(
X,
base_margin=base_margin,
output_margin=True,
validate_features=False,
**predict_params,
)
raw_preds, class_probs = transform_margin(margins)

# It seems that they use argmax of class probs,
# not of margin to get the prediction (Note: scala implementation)
preds = np.argmax(class_probs, axis=1)
result: Dict[str, pd.Series] = {
pred.raw_prediction: pd.Series(list(raw_preds)),
pred.prediction: pd.Series(preds),
pred.probability: pd.Series(list(class_probs)),
}
def _predict(
model: XGBModel, X: ArrayLike, base_margin: Optional[np.ndarray]
) -> Union[pd.DataFrame, pd.Series]:
margins = model.predict(
X,
base_margin=base_margin,
output_margin=True,
validate_features=False,
**predict_params,
)
raw_preds, class_probs = transform_margin(margins)

# It seems that they use argmax of class probs,
# not of margin to get the prediction (Note: scala implementation)
preds = np.argmax(class_probs, axis=1)
result: Dict[str, pd.Series] = {
pred.raw_prediction: pd.Series(list(raw_preds)),
pred.prediction: pd.Series(preds),
pred.probability: pd.Series(list(class_probs)),
}

if pred_contrib_col_name:
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
if pred_contrib_col_name is not None:
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))

yield pd.DataFrame(data=result)
return pd.DataFrame(data=result)

if has_base_margin:
pred_struct = predict_udf(struct(*features_col, base_margin_col))
else:
pred_struct = predict_udf(struct(*features_col))
return _predict

def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
pred_struct_col = "_prediction_struct"
dataset = dataset.withColumn(pred_struct_col, pred_struct)
dataset = dataset.withColumn(pred_struct_col, pred_col)

raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
if raw_prediction_col_name:
Expand All @@ -1342,7 +1335,8 @@ def predict_udf(
array_to_vector(getattr(col(pred_struct_col), pred.probability)),
)

if pred_contrib_col_name:
pred_contrib_col_name = self._get_pred_contrib_col_name()
if pred_contrib_col_name is not None:
dataset = dataset.withColumn(
pred_contrib_col_name,
getattr(col(pred_struct_col), pred.pred_contrib),
Expand Down

0 comments on commit 4d387cb

Please sign in to comment.