From 87e48ba41604cd74c986de33a7c45359fede6f51 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Wed, 9 Oct 2024 09:56:49 -0700 Subject: [PATCH] Provide 24.10 version bumps (#406) * Add support for XGBoost UBJSON * Bump package versions in Conda envs * Replace `convert_sklearn` script to use latest Treelite API * Update `generate_example_model.py` to use latest XGBoost API --- conda/environments/rapids_triton_dev.yml | 2 + conda/environments/triton_benchmark.yml | 10 +-- conda/environments/triton_test.yml | 16 ++--- conda/environments/triton_test_no_client.yml | 12 ++-- docs/model_config.md | 13 ++-- docs/sklearn_and_cuml.md | 4 +- notebooks/faq/FAQs.ipynb | 2 +- ops/Dockerfile | 4 +- qa/L0_e2e/generate_example_model.py | 28 +++++--- qa/L0_e2e/test_model.py | 69 +++++++++++++++----- qa/generate_example_models.sh | 15 ++++- scripts/convert_cuml.py | 12 ++-- scripts/convert_sklearn | 24 ------- scripts/convert_sklearn.py | 42 ++++++++++++ scripts/environment.yml | 6 +- src/model.h | 3 + src/serialization.h | 13 +++- src/tl_utils.h | 8 +++ 18 files changed, 191 insertions(+), 92 deletions(-) delete mode 100755 scripts/convert_sklearn create mode 100755 scripts/convert_sklearn.py diff --git a/conda/environments/rapids_triton_dev.yml b/conda/environments/rapids_triton_dev.yml index 9ecb60f9..ce367a5f 100644 --- a/conda/environments/rapids_triton_dev.yml +++ b/conda/environments/rapids_triton_dev.yml @@ -6,4 +6,6 @@ dependencies: - ccache - cmake>=3.26.4,!=3.30.0 - ninja + # TODO(hcho3): Remove the pin when + # https://github.com/triton-inference-server/common/pull/114 is merged - rapidjson>=1.1.0,<1.1.0.post* diff --git a/conda/environments/triton_benchmark.yml b/conda/environments/triton_benchmark.yml index 6ce08e3b..10772a9a 100644 --- a/conda/environments/triton_benchmark.yml +++ b/conda/environments/triton_benchmark.yml @@ -5,14 +5,14 @@ channels: - rapidsai dependencies: - cuda-version=11.8 - - cudf=23.12 - - libcusolver<=11.4.1.48 - - libcusparse<=12.0 + - cudf=24.08 + - libcusolver + - libcusparse - matplotlib - pip - python - scipy - pip: - tritonclient[all] - - protobuf==3.20.1 - - git+https://github.com/rapidsai/rapids-triton.git@branch-23.12#subdirectory=python + - protobuf + - git+https://github.com/rapidsai/rapids-triton.git@branch-24.04#subdirectory=python diff --git a/conda/environments/triton_test.yml b/conda/environments/triton_test.yml index 766f97d0..26968ff7 100644 --- a/conda/environments/triton_test.yml +++ b/conda/environments/triton_test.yml @@ -7,19 +7,19 @@ dependencies: - aws-sdk-cpp - clang-tools=11.1.0 - cuda-version=11.8 - - cudf=23.12 - - cuml=23.12 + - cudf=24.08 + - cuml=24.08 - flake8 - - hypothesis<6.46.8 + - hypothesis - lightgbm - matplotlib - pip - pytest - python - - rapidsai::xgboost>=1.7 - - scikit-learn=1.2.0 - - treelite + - rapidsai::xgboost>=2.1 + - scikit-learn>=1.5 + - treelite>=4.3 - pip: - tritonclient[all] - - protobuf==3.20.1 - - git+https://github.com/rapidsai/rapids-triton.git@branch-23.12#subdirectory=python + - protobuf + - git+https://github.com/rapidsai/rapids-triton.git@branch-24.04#subdirectory=python diff --git a/conda/environments/triton_test_no_client.yml b/conda/environments/triton_test_no_client.yml index cebbaea0..73dbe5c0 100644 --- a/conda/environments/triton_test_no_client.yml +++ b/conda/environments/triton_test_no_client.yml @@ -7,15 +7,15 @@ dependencies: - aws-sdk-cpp - clang-tools=11.1.0 - cuda-version=11.8 - - cudf=23.12 - - cuml=23.12 + - cudf=24.08 + - cuml=24.08 - flake8 - - hypothesis<6.46.8 + - hypothesis - 
lightgbm - pip - pytest - python - python-rapidjson - - rapidsai::xgboost>=1.7 - - scikit-learn=1.2.0 - - treelite + - rapidsai::xgboost>=2.1 + - scikit-learn>=1.5 + - treelite>=4.3 diff --git a/docs/model_config.md b/docs/model_config.md index e8a8812d..b1630ebf 100644 --- a/docs/model_config.md +++ b/docs/model_config.md @@ -70,7 +70,7 @@ instance_group [{ kind: KIND_AUTO }] parameters [ { key: "model_type" - value: { string_value: "xgboost_json" } + value: { string_value: "xgboost_ubj" } }, { key: "output_class" @@ -185,23 +185,24 @@ Treelite's checkpoint format. For more information, see [Model Support](model_support.md). The `model_type` option is used to indicate which of these serialization -formats your model uses: `xgboost` for XGBoost binary, `xgboost_json` for -XGBoost JSON, `lightgbm` for LightGBM, or `treelite_checkpoint` for -Treelite: +formats your model uses: `xgboost_ubj` for XGBoost UBJSON, `xgboost_json` for +XGBoost JSON, `xgboost` for XGBoost binary (legacy), `lightgbm` for LightGBM, +or `treelite_checkpoint` for Treelite: ``` parameters [ { key: "model_type" - value: { string_value: "xgboost_json" } + value: { string_value: "xgboost_ubj" } } ] ``` #### Model Filenames For each model type, Triton expects a particular default filename: -- `xgboost.model` for XGBoost Binary +- `xgboost.ubj` for XGBoost UBJSON - `xgboost.json` for XGBoost JSON +- `xgboost.model` for XGBoost Binary (Legacy) - `model.txt` for LightGBM - `checkpoint.tl` for Treelite It is recommended that you use these filenames, but custom filenames can be diff --git a/docs/sklearn_and_cuml.md b/docs/sklearn_and_cuml.md index 85e8cbe8..fece57a3 100644 --- a/docs/sklearn_and_cuml.md +++ b/docs/sklearn_and_cuml.md @@ -48,7 +48,7 @@ model framework in Triton simply by exporting to the binary checkpoint format. The FIL backend repo includes scripts for easy conversion from pickle-serialized cuML or Scikit-Learn models to Treelite checkpoints. You can download the relevant script for Scikit-Learn -[here](https://raw.githubusercontent.com/triton-inference-server/fil_backend/main/scripts/convert_sklearn) +[here](https://raw.githubusercontent.com/triton-inference-server/fil_backend/main/scripts/convert_sklearn.py) and for cuML [here](https://raw.githubusercontent.com/triton-inference-server/fil_backend/main/scripts/convert_cuml.py). 
@@ -89,7 +89,7 @@ model_repository/ Then perform the conversion by running either: ```bash -./convert_sklearn model_repository/fil/1/model.pkl +./convert_sklearn.py model_repository/fil/1/model.pkl ``` for Scikit-Learn models or ```bash diff --git a/notebooks/faq/FAQs.ipynb b/notebooks/faq/FAQs.ipynb index 7f82cf34..873f482d 100644 --- a/notebooks/faq/FAQs.ipynb +++ b/notebooks/faq/FAQs.ipynb @@ -593,7 +593,7 @@ " archival_path = os.path.join(VERSIONED_DIR, 'model.pkl')\n", " shutil.copy(MODEL_PATH, archival_path)\n", " \n", - " !../../scripts/convert_sklearn {archival_path}" + " !../../scripts/convert_sklearn.py {archival_path}" ] }, { diff --git a/ops/Dockerfile b/ops/Dockerfile index 9546856d..d1e85046 100644 --- a/ops/Dockerfile +++ b/ops/Dockerfile @@ -3,7 +3,7 @@ # Arguments for controlling build details ########################################################################################### # Version of Triton to use -ARG TRITON_VERSION=24.08 +ARG TRITON_VERSION=24.09 # Base container image ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:${TRITON_VERSION}-py3 # Whether or not to enable GPU build @@ -56,7 +56,7 @@ RUN conda run --no-capture-output -n triton_test \ FROM wheel-install-${USE_CLIENT_WHEEL} as conda-test RUN conda run --no-capture-output -n triton_test \ - pip install git+https://github.com/rapidsai/rapids-triton.git@branch-21.12#subdirectory=python + pip install git+https://github.com/rapidsai/rapids-triton.git@branch-24.04#subdirectory=python RUN conda-pack --ignore-missing-files -n triton_test -o /tmp/env.tar \ && mkdir /conda/test/ \ && cd /conda/test/ \ diff --git a/qa/L0_e2e/generate_example_model.py b/qa/L0_e2e/generate_example_model.py index fd87ba30..c8477cfd 100644 --- a/qa/L0_e2e/generate_example_model.py +++ b/qa/L0_e2e/generate_example_model.py @@ -72,11 +72,10 @@ def train_xgboost_classifier(data, labels, depth=25, trees=100): training_params = { "eval_metric": "error", "objective": "binary:logistic", - "tree_method": "gpu_hist", + "tree_method": "hist", + "device": "cuda", "max_depth": depth, "n_estimators": trees, - "use_label_encoder": False, - "predictor": "gpu_predictor", } model = xgb.XGBClassifier(**training_params) @@ -192,10 +191,10 @@ def train_xgboost_regressor(data, targets, depth=25, trees=100): training_params = { "objective": "reg:squarederror", - "tree_method": "gpu_hist", + "tree_method": "hist", + "device": "cuda", "max_depth": depth, "n_estimators": trees, - "predictor": "gpu_predictor", } model = xgb.XGBRegressor(**training_params) @@ -307,13 +306,19 @@ def generate_model( def serialize_model(model, directory, output_format="xgboost"): if output_format == "xgboost": - model_path = os.path.join(directory, "xgboost.model") + model_path = os.path.join(directory, "xgboost.deprecated") model.save_model(model_path) - return model_path + new_model_path = os.path.join(directory, "xgboost.model") + os.rename(model_path, new_model_path) + return new_model_path if output_format == "xgboost_json": model_path = os.path.join(directory, "xgboost.json") model.save_model(model_path) return model_path + if output_format == "xgboost_ubj": + model_path = os.path.join(directory, "xgboost.ubj") + model.save_model(model_path) + return model_path if output_format == "lightgbm": model_path = os.path.join(directory, "model.txt") model.save_model(model_path) @@ -462,6 +467,8 @@ def build_model( if output_format is None: if model_type == "xgboost": + # TODO(hcho3): Update to "xgboost_ubj" when XGBoost removes support + # for legacy binary format output_format = 
"xgboost" elif model_type == "lightgbm": output_format = "lightgbm" @@ -471,7 +478,10 @@ def build_model( raise RuntimeError('Unknown model type "{}"'.format(model_type)) if ( - (model_type == "xgboost" and output_format not in {"xgboost", "xgboost_json"}) + ( + model_type == "xgboost" + and output_format not in {"xgboost", "xgboost_json", "xgboost_ubj"} + ) or (model_type == "lightgbm" and output_format not in {"lightgbm"}) or (model_type == "sklearn" and output_format not in {"pickle"}) or (model_type == "cuml" and output_format not in {"pickle"}) @@ -545,7 +555,7 @@ def parse_args(): ) parser.add_argument( "--format", - choices=("xgboost", "xgboost_json", "lightgbm", "pickle"), + choices=("xgboost", "xgboost_json", "xgboost_ubj", "lightgbm", "pickle"), default=None, help="serialization format for model", ) diff --git a/qa/L0_e2e/test_model.py b/qa/L0_e2e/test_model.py index ecfc80d4..d2d94531 100644 --- a/qa/L0_e2e/test_model.py +++ b/qa/L0_e2e/test_model.py @@ -36,6 +36,7 @@ "xgboost", "xgboost_shap", "xgboost_json", + "xgboost_ubj", "lightgbm", "lightgbm_rf", "regression", @@ -66,6 +67,17 @@ def valid_shm_modes(): return tuple(modes) +# TODO(hcho3): Remove once we fix the flakiness of CUDA shared mem +# See https://github.com/triton-inference-server/server/issues/7688 +def shared_mem_parametrize(): + params = [None] + if "cuda" in valid_shm_modes(): + params.append( + pytest.param("cuda", marks=pytest.mark.xfail(reason="shared mem is flaky")), + ) + return params + + @pytest.fixture(scope="session") def client(): """A RAPIDS-Triton client for submitting inference requests""" @@ -98,27 +110,46 @@ class GTILModel: """A compatibility wrapper for executing models with GTIL""" def __init__(self, model_path, model_format, output_class): - if model_format == "treelite_checkpoint": + if model_format == "xgboost": + self.tl_model = treelite.frontend.load_xgboost_model_legacy_binary( + model_path + ) + elif model_format == "xgboost_json": + self.tl_model = treelite.frontend.load_xgboost_model( + model_path, format_choice="json" + ) + elif model_format == "xgboost_ubj": + self.tl_model = treelite.frontend.load_xgboost_model( + model_path, format_choice="ubjson" + ) + elif model_format == "lightgbm": + self.tl_model = treelite.frontend.load_lightgbm_model(model_path) + elif model_format == "treelite_checkpoint": self.tl_model = treelite.Model.deserialize(model_path) - else: - self.tl_model = treelite.Model.load(model_path, model_format) self.output_class = output_class def _predict(self, arr): - return treelite.gtil.predict(self.tl_model, arr) + result = treelite.gtil.predict(self.tl_model, arr) + # GTIL always returns prediction result with dimensions + # (num_row, num_target, num_class) + assert len(result.shape) == 3 + # We don't test multi-target models + # TODO(hcho3): Add coverage for multi-target models + assert result.shape[1] == 1 + return result[:, 0, :] def predict_proba(self, arr): result = self._predict(arr) - if len(result.shape) > 1: + if result.shape[1] > 1: return result else: - return np.transpose(np.vstack((1 - result, result))) + return np.hstack((1 - result, result)) def predict(self, arr): if self.output_class: return np.argmax(self.predict_proba(arr), axis=1) else: - return self._predict(arr) + return self._predict(arr).squeeze() class GroundTruthModel: @@ -144,6 +175,8 @@ def __init__( model_path = os.path.join(model_dir, "xgboost.model") elif model_format == "xgboost_json": model_path = os.path.join(model_dir, "xgboost.json") + elif model_format == "xgboost_ubj": + 
model_path = os.path.join(model_dir, "xgboost.ubj") elif model_format == "lightgbm": model_path = os.path.join(model_dir, "model.txt") elif model_format == "treelite_checkpoint": @@ -220,12 +253,13 @@ def model_data(request, client, model_repo): ) +@pytest.mark.parametrize("shared_mem", shared_mem_parametrize()) @given(hypothesis_data=st.data()) @settings( deadline=None, suppress_health_check=(HealthCheck.too_slow, HealthCheck.filter_too_much), ) -def test_small(client, model_data, hypothesis_data): +def test_small(shared_mem, client, model_data, hypothesis_data): """Test Triton-served model on many small Hypothesis-generated examples""" all_model_inputs = defaultdict(list) total_output_sizes = {} @@ -251,9 +285,6 @@ def test_small(client, model_data, hypothesis_data): model_output_sizes = { name: size for name, size in model_data.output_sizes.items() } - shared_mem = hypothesis_data.draw( - st.one_of(st.just(mode) for mode in valid_shm_modes()) - ) result = client.predict( model_data.name, model_inputs, @@ -298,11 +329,11 @@ def test_small(client, model_data, hypothesis_data): ) # Test entire batch of Hypothesis-generated inputs at once - shared_mem = hypothesis_data.draw( - st.one_of(st.just(mode) for mode in valid_shm_modes()) - ) all_triton_outputs = client.predict( - model_data.name, all_model_inputs, total_output_sizes, shared_mem=shared_mem + model_data.name, + all_model_inputs, + total_output_sizes, + shared_mem=shared_mem, ) for output_name in sorted(ground_truth.keys()): @@ -324,7 +355,7 @@ def test_small(client, model_data, hypothesis_data): ) -@pytest.mark.parametrize("shared_mem", valid_shm_modes()) +@pytest.mark.parametrize("shared_mem", shared_mem_parametrize()) def test_max_batch(client, model_data, shared_mem): """Test processing of a single maximum-sized batch""" max_inputs = { @@ -335,9 +366,11 @@ def test_max_batch(client, model_data, shared_mem): name: size * model_data.max_batch_size for name, size in model_data.output_sizes.items() } - shared_mem = valid_shm_modes()[0] result = client.predict( - model_data.name, max_inputs, model_output_sizes, shared_mem=shared_mem + model_data.name, + max_inputs, + model_output_sizes, + shared_mem=shared_mem, ) ground_truth = model_data.ground_truth_model.predict(max_inputs) diff --git a/qa/generate_example_models.sh b/qa/generate_example_models.sh index 740a232b..f94ff13b 100755 --- a/qa/generate_example_models.sh +++ b/qa/generate_example_models.sh @@ -25,7 +25,7 @@ SCRIPTS_DIR="${QA_DIR}/../scripts" MODEL_REPO="${QA_DIR}/L0_e2e/model_repository" GENERATOR_SCRIPT="python ${QA_DIR}/L0_e2e/generate_example_model.py" -SKLEARN_CONVERTER="${SCRIPTS_DIR}/convert_sklearn" +SKLEARN_CONVERTER="${SCRIPTS_DIR}/convert_sklearn.py" CUML_CONVERTER="${SCRIPTS_DIR}/convert_cuml.py" models=() @@ -56,6 +56,19 @@ then models+=( $name ) fi +name=xgboost_ubj +if [ $RETRAIN -ne 0 ] || [ ! -d "${MODEL_REPO}/${name}" ] +then + ${GENERATOR_SCRIPT} \ + --name $name \ + --format xgboost_ubj \ + --depth 7 \ + --trees 500 \ + --features 500 \ + --predict_proba + models+=( $name ) +fi + name=xgboost_shap if [ $RETRAIN -ne 0 ] || [ ! 
-d "${MODEL_REPO}/${name}" ] then diff --git a/scripts/convert_cuml.py b/scripts/convert_cuml.py index 7ae16446..89d3031a 100755 --- a/scripts/convert_cuml.py +++ b/scripts/convert_cuml.py @@ -21,7 +21,7 @@ """ import argparse -import os +import pathlib import pickle if __name__ == "__main__": @@ -29,10 +29,10 @@ parser.add_argument("pickle_file", help="Path to the pickle file to convert") args = parser.parse_args() - with open(args.pickle_file, "rb") as file_: - model = pickle.load(file_) + with open(args.pickle_file, "rb") as f: + model = pickle.load(f) - model_dir = os.path.dirname(args.pickle_file) - out_path = os.path.join(model_dir, "checkpoint.tl") + model_dir = pathlib.Path(args.pickle_file).resolve().parent + out_path = model_dir / "checkpoint.tl" - model.convert_to_treelite_model().to_treelite_checkpoint(out_path) + model.convert_to_treelite_model().to_treelite_checkpoint(str(out_path)) diff --git a/scripts/convert_sklearn b/scripts/convert_sklearn deleted file mode 100755 index 9a1bbab4..00000000 --- a/scripts/convert_sklearn +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -if [ $# -ne 1 ] || [ ! -f $1 ] -then - echo "USAGE: convert_sklearn PKL_FILE" -else - out_file="$(dirname $1)/checkpoint.tl" - - python -m treelite.serialize --input-model "$1" \ - --input-model-type sklearn_pkl --output-checkpoint "$out_file" -fi diff --git a/scripts/convert_sklearn.py b/scripts/convert_sklearn.py new file mode 100755 index 00000000..0c8804c2 --- /dev/null +++ b/scripts/convert_sklearn.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""sklearn RF/GBDT to Treelite checkpoint converter + +Given a path to a pickle file containing a scikit-learn random forest (or +gradient boosting) model, this script will generate a Treelite checkpoint file +representation of the model in the same directory. 
+""" + +import argparse +import pathlib +import pickle + +import treelite + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("pickle_file", help="Path to the pickle file to convert") + args = parser.parse_args() + + with open(args.pickle_file, "rb") as f: + model = pickle.load(f) + + model_dir = pathlib.Path(args.pickle_file).resolve().parent + out_path = model_dir / "checkpoint.tl" + + tl_model = treelite.sklearn.import_model(model) + tl_model.serialize(out_path) diff --git a/scripts/environment.yml b/scripts/environment.yml index 1daeefa7..e8e16979 100644 --- a/scripts/environment.yml +++ b/scripts/environment.yml @@ -5,7 +5,7 @@ channels: - rapidsai dependencies: - cuda-version=11.8 - - cuml=23.12 + - cuml=24.08 - python - - scikit-learn - - treelite + - scikit-learn>=1.5 + - treelite>=4.3 diff --git a/src/model.h b/src/model.h index 0e46da2e..9b715e83 100644 --- a/src/model.h +++ b/src/model.h @@ -151,6 +151,9 @@ struct RapidsModel : rapids::Model { case SerializationFormat::xgboost_json: path /= "xgboost.json"; break; + case SerializationFormat::xgboost_ubj: + path /= "xgboost.ubj"; + break; case SerializationFormat::lightgbm: path /= "model.txt"; break; diff --git a/src/serialization.h b/src/serialization.h index ab26f6d1..05ea0ee0 100644 --- a/src/serialization.h +++ b/src/serialization.h @@ -23,7 +23,13 @@ namespace triton { namespace backend { namespace NAMESPACE { -enum struct SerializationFormat { xgboost, xgboost_json, lightgbm, treelite }; +enum struct SerializationFormat { + xgboost, + xgboost_json, + xgboost_ubj, + lightgbm, + treelite +}; inline auto string_to_serialization(std::string const& type_string) @@ -34,6 +40,8 @@ string_to_serialization(std::string const& type_string) result = SerializationFormat::xgboost; } else if (type_string == "xgboost_json") { result = SerializationFormat::xgboost_json; + } else if (type_string == "xgboost_ubj") { + result = SerializationFormat::xgboost_ubj; } else if (type_string == "lightgbm") { result = SerializationFormat::lightgbm; } else if (type_string == "treelite_checkpoint") { @@ -60,6 +68,9 @@ serialization_to_string(SerializationFormat format) case SerializationFormat::xgboost_json: result = "xgboost_json"; break; + case SerializationFormat::xgboost_ubj: + result = "xgboost_ubj"; + break; case SerializationFormat::lightgbm: result = "lightgbm"; break; diff --git a/src/tl_utils.h b/src/tl_utils.h index 218edf77..304b7473 100644 --- a/src/tl_utils.h +++ b/src/tl_utils.h @@ -52,6 +52,14 @@ load_tl_base_model( model_file, config_str); break; } + case SerializationFormat::xgboost_ubj: { + auto config_str = + std::string("{\"allow_unknown_field\": ") + + std::string(xgboost_allow_unknown_field ? "true" : "false") + "}"; + result = treelite::model_loader::LoadXGBoostModelUBJSON( + model_file, config_str); + break; + } case SerializationFormat::lightgbm: result = treelite::model_loader::LoadLightGBMModel(model_file); break;