From f6f7d1bcf7931bb2a845daab88ae37852ca9ae2d Mon Sep 17 00:00:00 2001
From: Leon Kiefer
Date: Mon, 22 Jan 2024 13:18:40 +0100
Subject: [PATCH] Add resolve_model_relative_to_config_file config option

Signed-off-by: Leon Kiefer
---
 README.md                               |  4 ++++
 ci/L0_backend_vllm/vllm_backend/test.sh | 17 ++++++++++++++++-
 .../vllm_backend/vllm_backend_test.py   | 23 ++++++++++++++++++-----
 src/model.py                            |  8 ++++++++
 4 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 969830dd..f84ecd2d 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,10 @@ Specifically,
 and
 [here](https://github.com/vllm-project/vllm/blob/ee8217e5bee5860469204ee57077a91138c9af02/vllm/engine/arg_utils.py#L201).
 
+When using local model files, specify the path to the model in the `model` field.
+By default, relative paths are resolved against the working directory of the Triton server process.
+To resolve the `model` path against the directory containing `model.json` instead, set the `resolve_model_relative_to_config_file` field to `true`.
+
 For multi-GPU support, EngineArgs like tensor_parallel_size can be specified
 in [model.json](samples/model_repository/vllm_model/1/model.json).
 
diff --git a/ci/L0_backend_vllm/vllm_backend/test.sh b/ci/L0_backend_vllm/vllm_backend/test.sh
index 32520b5d..5a8c66be 100755
--- a/ci/L0_backend_vllm/vllm_backend/test.sh
+++ b/ci/L0_backend_vllm/vllm_backend/test.sh
@@ -39,12 +39,26 @@ SAMPLE_MODELS_REPO="../../../samples/model_repository"
-EXPECTED_NUM_TESTS=3
+EXPECTED_NUM_TESTS=4
 
 rm -rf models && mkdir -p models
+
+# Operational vllm model
 cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
 
+# Python add_sub model
 mkdir -p models/add_sub/1/
 wget -P models/add_sub/1/ https://raw.githubusercontent.com/triton-inference-server/python_backend/main/examples/add_sub/model.py
 wget -P models/add_sub https://raw.githubusercontent.com/triton-inference-server/python_backend/main/examples/add_sub/config.pbtxt
 
+# Local vllm model
+cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_local
+sed -i 's/"facebook\/opt-125m"/".\/local_model"/' models/vllm_local/1/model.json
+sed -i '/"model": /a "resolve_model_relative_to_config_file": true,' models/vllm_local/1/model.json
+wget -P models/vllm_local/1/local_model https://huggingface.co/facebook/opt-125m/resolve/main/config.json
+wget -P models/vllm_local/1/local_model https://huggingface.co/facebook/opt-125m/resolve/main/merges.txt
+wget -P models/vllm_local/1/local_model https://huggingface.co/facebook/opt-125m/resolve/main/pytorch_model.bin
+wget -P models/vllm_local/1/local_model https://huggingface.co/facebook/opt-125m/resolve/main/special_tokens_map.json
+wget -P models/vllm_local/1/local_model https://huggingface.co/facebook/opt-125m/resolve/main/tokenizer_config.json
+wget -P models/vllm_local/1/local_model https://huggingface.co/facebook/opt-125m/resolve/main/vocab.json
+
 # Invalid model attribute
 cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_invalid_1/
 sed -i 's/"disable_log_requests"/"invalid_attribute"/' models/vllm_invalid_1/1/model.json
@@ -53,6 +67,7 @@ sed -i 's/"disable_log_requests"/"invalid_attribute"/' models/vllm_invalid_1/1/m
 cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_invalid_2/
 sed -i 's/"facebook\/opt-125m"/"invalid_model"/' models/vllm_invalid_2/1/model.json
 
+
 RET=0
 
 run_server
diff --git a/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py b/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
index cd953746..314abebe 100644
--- a/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
+++ b/ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
@@ -41,6 +41,7 @@ def setUp(self):
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
         self.vllm_model_name = "vllm_opt"
         self.python_model_name = "add_sub"
+        self.local_vllm_model_name = "vllm_local"
 
     def test_vllm_triton_backend(self):
         # Load both vllm and add_sub models
@@ -60,9 +61,21 @@ def test_vllm_triton_backend(self):
         self.assertFalse(self.triton_client.is_model_ready(self.python_model_name))
 
         # Test vllm model and unload vllm model
-        self._test_vllm_model(send_parameters_as_tensor=True)
-        self._test_vllm_model(send_parameters_as_tensor=False)
+        self._test_vllm_model(self.vllm_model_name, send_parameters_as_tensor=True)
+        self._test_vllm_model(self.vllm_model_name, send_parameters_as_tensor=False)
         self.triton_client.unload_model(self.vllm_model_name)
+
+    def test_local_vllm_model(self):
+        # Load local vllm model
+        self.triton_client.load_model(self.local_vllm_model_name)
+        self.assertTrue(self.triton_client.is_model_ready(self.local_vllm_model_name))
+
+        # Test local vllm model
+        self._test_vllm_model(self.local_vllm_model_name, send_parameters_as_tensor=True)
+        self._test_vllm_model(self.local_vllm_model_name, send_parameters_as_tensor=False)
+
+        # Unload local vllm model
+        self.triton_client.unload_model(self.local_vllm_model_name)
 
     def test_model_with_invalid_attributes(self):
         model_name = "vllm_invalid_1"
@@ -74,7 +87,7 @@ def test_vllm_invalid_model_name(self):
         with self.assertRaises(InferenceServerException):
             self.triton_client.load_model(model_name)
 
-    def _test_vllm_model(self, send_parameters_as_tensor):
+    def _test_vllm_model(self, model_name, send_parameters_as_tensor):
         user_data = UserData()
         stream = False
         prompts = [
@@ -92,11 +105,11 @@ def _test_vllm_model(self, send_parameters_as_tensor):
                 i,
                 stream,
                 sampling_parameters,
-                self.vllm_model_name,
+                model_name,
                 send_parameters_as_tensor,
             )
             self.triton_client.async_stream_infer(
-                model_name=self.vllm_model_name,
+                model_name=model_name,
                 request_id=request_data["request_id"],
                 inputs=request_data["inputs"],
                 outputs=request_data["outputs"],
diff --git a/src/model.py b/src/model.py
index 80f51320..fd22d07a 100644
--- a/src/model.py
+++ b/src/model.py
@@ -112,6 +112,14 @@ def initialize(self, args):
         with open(engine_args_filepath) as file:
             vllm_engine_config = json.load(file)
 
+        # Resolve the model path relative to the config file
+        if vllm_engine_config.pop("resolve_model_relative_to_config_file", False):
+            vllm_engine_config["model"] = os.path.abspath(
+                os.path.join(
+                    pb_utils.get_model_dir(), vllm_engine_config["model"]
+                )
+            )
+
         # Create an AsyncLLMEngine from the config from JSON
         self.llm_engine = AsyncLLMEngine.from_engine_args(
             AsyncEngineArgs(**vllm_engine_config)
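
Reviewer note: the CI setup above generates `models/vllm_local/1/model.json` via the two `sed` edits rather than checking a file in. A minimal sketch of the resulting config, assuming the remaining sample `model.json` fields, such as `disable_log_requests`, carry over unchanged (the value shown for it below is illustrative):

    {
        "model": "./local_model",
        "resolve_model_relative_to_config_file": true,
        "disable_log_requests": true
    }

With the new flag set, `initialize()` pops `resolve_model_relative_to_config_file` out of the loaded config and rewrites `model` to `os.path.abspath(os.path.join(pb_utils.get_model_dir(), "./local_model"))`, i.e. the `local_model` directory that the `wget` calls populate under `models/vllm_local/1/`, so loading the weights no longer depends on the Triton server's working directory. Without the flag, the relative path is handed to vLLM as-is and only resolves if the server happens to run from that directory.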