Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add instructions for running vLLM backend #8

Merged
merged 52 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
1688a33
Draft README and samples
dyastremsky Oct 10, 2023
0ba6200
Run pre-commit
dyastremsky Oct 10, 2023
a4921c1
Remove unused queue.
dyastremsky Oct 10, 2023
92124bf
Fixes for README
dyastremsky Oct 10, 2023
aa8a105
Add client.py shebang
dyastremsky Oct 10, 2023
ed108d0
Add Conda instructions.
dyastremsky Oct 10, 2023
c5213f6
Spacing, title
dyastremsky Oct 10, 2023
2c6881c
Switch i/o to lowercase
dyastremsky Oct 10, 2023
ac33407
Switch i/o to lowercase
dyastremsky Oct 10, 2023
d2fdb3f
Switch i/o to lowercase
dyastremsky Oct 10, 2023
02c1167
Switch i/o to lowercase
dyastremsky Oct 10, 2023
d164dab
Change client code to use lowercase inputs/outputs
dyastremsky Oct 10, 2023
5ed4d0e
Merge branch 'main' of https://github.com/triton-inference-server/vll…
dyastremsky Oct 10, 2023
0cd3d91
Merge branch 'dyas-README' of https://github.com/triton-inference-ser…
dyastremsky Oct 10, 2023
45a531f
Update client to use iterable client class
dyastremsky Oct 11, 2023
1e27105
Rename vLLM model, add note to config
dyastremsky Oct 11, 2023
97417c5
Remove unused imports and vars
dyastremsky Oct 11, 2023
d943de2
Clarify whaat Conda parameter is doing.
dyastremsky Oct 11, 2023
99943cc
Add clarifying note to model config
dyastremsky Oct 11, 2023
b08f426
Run pre-commit
dyastremsky Oct 11, 2023
682ad0c
Remove limitation, model name
dyastremsky Oct 11, 2023
e7578f1
Fix gen vllm env script name
dyastremsky Oct 11, 2023
502f4db
Update wording for supported models
dyastremsky Oct 11, 2023
ea35a73
Merge branch 'dyas-README' of https://github.com/triton-inference-ser…
dyastremsky Oct 11, 2023
fe06416
Update capitalization
dyastremsky Oct 11, 2023
0144d33
Update wording around shared memory across servers
dyastremsky Oct 11, 2023
0f0f968
Remove extra note about shared memory hangs across servers
dyastremsky Oct 11, 2023
b81574d
Fix line lengths and clarify wording.
dyastremsky Oct 11, 2023
faa29a6
Add container steps
dyastremsky Oct 12, 2023
4259a7e
Add links to engine args, define model.json
dyastremsky Oct 12, 2023
76c2d89
Change verbiage around vLLM engine models
dyastremsky Oct 12, 2023
31f1733
Fix links
dyastremsky Oct 12, 2023
76d0652
Fix links, grammar
dyastremsky Oct 12, 2023
a50ae8d
Remove Conda references.
dyastremsky Oct 12, 2023
edaff54
Fix client I/O and model names
dyastremsky Oct 12, 2023
33dbaed
Remove model name in config
dyastremsky Oct 12, 2023
6575197
Add generate endpoint, switch to min container
dyastremsky Oct 12, 2023
9effb18
Change to min
dyastremsky Oct 12, 2023
8dc3f51
Apply suggestions from code review
oandreeva-nv Oct 13, 2023
bf0d905
Update README.md
oandreeva-nv Oct 13, 2023
7ec9b5f
Add example model args, link to multi-server behavior
dyastremsky Oct 17, 2023
3b64abc
Format client input, add upstream tag.
dyastremsky Oct 17, 2023
3a3b326
Fix links, grammar
dyastremsky Oct 17, 2023
48e08e7
Add quotes to shm-region-prefix-name
dyastremsky Oct 17, 2023
9b4a193
Update sentence ordering, remove extra issues link
dyastremsky Oct 17, 2023
45be0f6
Modify input text example, one arg per line
dyastremsky Oct 17, 2023
204ce5a
Remove line about CUDA version compatibility.
dyastremsky Oct 18, 2023
8c9c4e7
Wording of Triton container option
dyastremsky Oct 18, 2023
3ab4774
Update wording of pre-built Docker container option
dyastremsky Oct 18, 2023
757e2b2
Update README.md wording
dyastremsky Oct 18, 2023
aa9ec65
Update wording - add "the"
dyastremsky Oct 18, 2023
e0161f4
Standarize capitalization, headings
dyastremsky Oct 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
<!--
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

[![License](https://img.shields.io/badge/License-BSD3-lightgrey.svg)](https://opensource.org/licenses/BSD-3-Clause)

# vLLM Backend

The Triton backend for [vLLM](https://github.com/vllm-project/vllm).
You can learn more about Triton backends in the [backend
repo](https://github.com/triton-inference-server/backend). Ask
questions or report problems on the [issues
page](https://github.com/triton-inference-server/server/issues).
This backend is designed to run [vLLM](https://github.com/vllm-project/vllm)
rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved
with
[one of the HuggingFace models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
dyastremsky marked this conversation as resolved.
Show resolved Hide resolved
it supports.

Where can I ask general questions about Triton and Triton backends?
Be sure to read all the information below as well as the [general
Triton documentation](https://github.com/triton-inference-server/server#triton-inference-server)
available in the main [server](https://github.com/triton-inference-server/server)
repo. If you don't find your answer there you can ask questions on the
main Triton [issues page](https://github.com/triton-inference-server/server/issues).

## Build the vLLM Backend

As a Python-based backend, your Triton server just needs to have the [Python backend](https://github.com/triton-inference-server/python_backend)
located in the backends directory: `/opt/tritonserver/backends/python`. After that, you can save the vLLM backend in the backends folder as `/opt/tritonserver/backends/vllm`. The `model.py` file in the `src` directory should be in the vllm folder and will function as your Python-based backend.
oandreeva-nv marked this conversation as resolved.
Show resolved Hide resolved

In other words, there are no build steps. You only need to copy this to your Triton backends repository. If you use the official Triton vLLM container, this is already set up for you.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would installing dependencies be part of build? Or do we need a seperate section on dependencies?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I'll add this. I had made the assumption that this is using the vLLM backend, but we need to clarify/offer an independent build (e.g. adding these to a general Triton container).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a couple of options, how you can build vLLM backend.
Option 1. You can follow steps described in Building With Docker and use build.py script.
The sample command will build a Triton Server container with all available options enabled:

./build.py -v --image=base,${BASE_CONTAINER_IMAGE_NAME}
                --enable-logging --enable-stats --enable-tracing
                --enable-metrics --enable-gpu-metrics --enable-cpu-metrics
                --enable-gpu
                --filesystem=gcs --filesystem=s3 --filesystem=azure_storage
                --endpoint=http --endpoint=grpc --endpoint=sagemaker --endpoint=vertex-ai
                --backend=python:r23.10
                --backend=vllm:r23.10

Option 2. You can install vLLM backend directly into our NGC Triton container. In this case, please install vllm first: pip install vllm, then set up vllm_backend in the container as follows:

mkdir -p /opt/tritonserver/backends/vllm_backend
wget -P /opt/tritonserver/backends/vllm_backend https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/src/model.py

Note: we should also mention separate container at some point

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for drafting these instructions. Added!


The backend repository should look like this:
```
/opt/tritonserver/backends/
`-- vllm
oandreeva-nv marked this conversation as resolved.
Show resolved Hide resolved
|-- model.py
-- python
|-- libtriton_python.so
|-- triton_python_backend_stub
|-- triton_python_backend_utils.py
```

rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved
## Using the vLLM Backend

You can see an example model_repository in the `samples` folder.
You can use this as is and change the model by changing the `model` value in `model.json`.
rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved
You can change the GPU utilization and logging parameters in that file as well.

tanmayv25 marked this conversation as resolved.
Show resolved Hide resolved
In the `samples` folder, you can also find a sample client, `client.py`.
rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved
This client is meant to function similarly to the Triton
rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved
[vLLM example](https://github.com/triton-inference-server/tutorials/tree/main/Quick_Deploy/vLLM).
By default, this will test `prompts.txt`, which we have included in the samples folder.

## Running the Latest vLLM Version

By default, the vLLM backend uses the version of vLLM that is available via Pip.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by default this is in the pre-built container?

Maybe this is best described as the vLLM version installed in the system (or container).

rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved
These are compatible with the newer versions of CUDA running in Triton.
If you would like to use a specific vLLM commit or the latest version of vLLM, you
will need to use a
[custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments).
Please see the
[conda](samples/conda) subdirectory of the `samples` folder for information on how to do so.
oandreeva-nv marked this conversation as resolved.
Show resolved Hide resolved

## Important Notes

* At present, Triton only supports one Python-based backend per server. If you try to start multiple vLLM models, you will get an error.

### Running Multiple Instances of Triton Server

Python-based backends use shared memory to transfer requests to the stub process. When running multiple instances of Triton Server on the same machine that use Python-based backend models, there would be shared memory region name conflicts that can result in segmentation faults or hangs. In order to avoid this issue, you need to specify different shm-region-prefix-name using the --backend-config flag.
dyastremsky marked this conversation as resolved.
Show resolved Hide resolved
```
# Triton instance 1
tritonserver --model-repository=/models --backend-config=python,shm-region-prefix-name=prefix1

# Triton instance 2
tritonserver --model-repository=/models --backend-config=python,shm-region-prefix-name=prefix2
```
Note that the hangs would only occur if the /dev/shm is shared between the two instances of the server. If you run the servers in different containers that do not share this location, you do not need to specify shm-region-prefix-name.
dyastremsky marked this conversation as resolved.
Show resolved Hide resolved
239 changes: 239 additions & 0 deletions samples/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import asyncio
import json
import queue
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
import sys
from os import system
Fixed Show fixed Hide fixed

import numpy as np
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import *


class LLMClient:
def __init__(self, flags: argparse.Namespace):
self._client = grpcclient.InferenceServerClient(
url=flags.url, verbose=flags.verbose
)
self._flags = flags
self._loop = asyncio.get_event_loop()
self._results_dict = {}

async def async_request_iterator(self, prompts, sampling_parameters):
try:
for iter in range(self._flags.iterations):
for i, prompt in enumerate(prompts):
prompt_id = self._flags.offset + (len(prompts) * iter) + i
self._results_dict[str(prompt_id)] = []
yield self.create_request(
prompt,
self._flags.streaming_mode,
prompt_id,
sampling_parameters,
)
except Exception as error:
print(f"Caught an error in the request iterator: {error}")

async def stream_infer(self, prompts, sampling_parameters):
try:
# Start streaming
response_iterator = self._client.stream_infer(
inputs_iterator=self.async_request_iterator(
prompts, sampling_parameters
),
stream_timeout=self._flags.stream_timeout,
)
async for response in response_iterator:
yield response
except InferenceServerException as error:
print(error)
sys.exit(1)

async def process_stream(self, prompts, sampling_parameters):
# Clear results in between process_stream calls
self.results_dict = []

# Read response from the stream
async for response in self.stream_infer(prompts, sampling_parameters):
result, error = response
if error:
print(f"Encountered error while processing: {error}")
else:
output = result.as_numpy("TEXT")
for i in output:
self._results_dict[result.get_response().id].append(i)

async def run(self):
sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
stream = self._flags.streaming_mode
Fixed Show fixed Hide fixed
with open(self._flags.input_prompts, "r") as file:
print(f"Loading inputs from `{self._flags.input_prompts}`...")
prompts = file.readlines()

await self.process_stream(prompts, sampling_parameters)

with open(self._flags.results_file, "w") as file:
for id in self._results_dict.keys():
for result in self._results_dict[id]:
file.write(result.decode("utf-8"))
file.write("\n")
file.write("\n=========\n\n")
print(f"Storing results into `{self._flags.results_file}`...")

if self._flags.verbose:
with open(self._flags.results_file, "r") as file:
print(f"\nContents of `{self._flags.results_file}` ===>")
print(file.read())

print("PASS: vLLM example")

def run_async(self):
self._loop.run_until_complete(self.run())

def create_request(
self,
prompt,
stream,
request_id,
sampling_parameters,
send_parameters_as_tensor=True,
):
inputs = []
prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
try:
inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES"))
inputs[-1].set_data_from_numpy(prompt_data)
except Exception as error:
print(f"Encountered an error during request creation: {error}")

stream_data = np.array([stream], dtype=bool)
inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL"))
inputs[-1].set_data_from_numpy(stream_data)

# Request parameters are not yet supported via BLS. Provide an
# optional mechanism to send serialized parameters as an input
# tensor until support is added

if send_parameters_as_tensor:
sampling_parameters_data = np.array(
[json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
)
inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES"))
inputs[-1].set_data_from_numpy(sampling_parameters_data)

# Add requested outputs
outputs = []
outputs.append(grpcclient.InferRequestedOutput("TEXT"))

# Issue the asynchronous sequence inference.
return {
"model_name": self._flags.model,
"inputs": inputs,
"outputs": outputs,
"request_id": str(request_id),
"parameters": sampling_parameters,
}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
type=str,
required=False,
default="vllm",
help="Model name",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
required=False,
default=False,
help="Enable verbose output",
)
parser.add_argument(
"-u",
"--url",
type=str,
required=False,
default="localhost:8001",
help="Inference server URL and its gRPC port. Default is localhost:8001.",
)
parser.add_argument(
"-t",
"--stream-timeout",
type=float,
required=False,
default=None,
help="Stream timeout in seconds. Default is None.",
)
parser.add_argument(
"--offset",
type=int,
required=False,
default=0,
help="Add offset to request IDs used",
)
parser.add_argument(
"--input-prompts",
type=str,
required=False,
default="prompts.txt",
help="Text file with input prompts",
)
parser.add_argument(
"--results-file",
type=str,
required=False,
default="results.txt",
help="The file with output results",
)
parser.add_argument(
"--iterations",
type=int,
required=False,
default=1,
help="Number of iterations through the prompts file",
)
parser.add_argument(
"-s",
"--streaming-mode",
action="store_true",
required=False,
default=False,
help="Enable streaming mode",
)
FLAGS = parser.parse_args()

client = LLMClient(FLAGS)
client.run_async()
Loading
Loading