[Infer] Serving example w/ ray-serve (multiple GPU case) (#4841)
* fix imports

* add ray-serve with Colossal-Infer tp

* trivial: send requests script

* add README

* fix worker port

* fix readme

* use app builder and autoscaling

* trivial: input args

* clean code; revise readme

* testci (skip example test)

* use auto model/tokenizer

* revert imports fix (fixed in other PRs)
yuanheng-zhao authored Oct 2, 2023
1 parent 3a74eb4 commit 573f270
Showing 4 changed files with 279 additions and 0 deletions.
151 changes: 151 additions & 0 deletions examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py
@@ -0,0 +1,151 @@
import logging
import os
from typing import Any, List, Union

import ray
import ray.util.collective as collective
import starlette
import torch
from pydantic import BaseModel
from ray import serve
from ray.serve import Application
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port

ray_serve_logger = logging.getLogger("ray.serve")


class GenConfigArgs(BaseModel):
    """Config for generation"""

    path: str
    tp_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32


def log_cuda_info(scope_name: str):
    ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
    ray_serve_logger.info(
        f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
    )
    if torch.cuda.is_available():
        ray_serve_logger.info(
            f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
        )
    else:
        ray_serve_logger.info(f" {scope_name}: cuda is not available!")


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, model_path: str, tp_size: int, max_batch_size: int, max_input_len: int, max_output_len: int):
        log_cuda_info("Worker.init")
        self.tp_size = tp_size
        self.model_path = model_path
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def setup(self, world_size, rank, port):
        # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # initialize and set distributed environment
        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
        log_cuda_info("Worker.setup")

        # Load model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
        )

        shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
        self.infer_engine = TPInferEngine(
            self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
        )
        self.generate_kwargs = dict(max_new_tokens=self.max_output_len, do_sample=False)

        return True

    def generate(self, text: Union[str, List[str]]) -> List[str]:
        input_tokens = self.tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
        ray_serve_logger.info(f"text: {text},\ninput_tokens: {input_tokens}")

        model_output = self.infer_engine.generate(input_tokens, **self.generate_kwargs)
        ray_serve_logger.info(f"model_output.shape: {model_output.shape}")

        text_output = []
        for i in range(len(model_output)):
            text_output.append(self.tokenizer.decode(model_output[i]))
        ray_serve_logger.info(f"output: {text_output}")

        return text_output


@serve.deployment(
    ray_actor_options={"num_cpus": 1, "num_gpus": 0},
    max_concurrent_queries=5,
    autoscaling_config={
        "target_num_ongoing_requests_per_replica": 1,
        "min_replicas": 1,
        "initial_replicas": 1,
        "max_replicas": 1,
    },
)
class Driver:
    def __init__(self, config: GenConfigArgs):
        log_cuda_info("Driver:init")
        model_path = config.path
        tp_size = config.tp_size

        self.num_workers = tp_size
        self.workers = []
        init_rets = []

        # Just grab a free port on localhost
        # NOTE workers in this communication group listen to the same port
        available_port = free_port()

        for i in range(self.num_workers):
            worker_name = "worker_idx_{}".format(i)
            w = Worker.options(name=worker_name).remote(
                model_path, self.num_workers, config.max_batch_size, config.max_input_len, config.max_output_len
            )
            self.workers.append(w)
            init_rets.append(w.setup.remote(self.num_workers, i, available_port))
        _options = {
            "group_name": "default_driver",
            "world_size": self.num_workers,
            "ranks": [i for i in range(self.num_workers)],
            "backend": "nccl",
        }
        collective.create_collective_group(self.workers, **_options)
        _ = ray.get(init_rets)

    # set batch wait delay in seconds and maximum number of sequences in a batch
    @serve.batch(batch_wait_timeout_s=0.8, max_batch_size=4)
    async def batch_generate(self, requests: List[str]):
        ray_serve_logger.info(f"Driver.batch_generate: requests length: {len(requests)}\n requests: {requests}")
        results = ray.get([w.generate.remote(requests) for w in self.workers])
        text_res = results[0]  # get any one of the copies
        return text_res

    async def __call__(self, request: starlette.requests.Request) -> Any:
        return await self.batch_generate(request.query_params["text"])


def app(args: GenConfigArgs) -> Application:
    print(args)
    if args.path is None or not os.path.exists(args.path):
        raise ValueError("Model path not provided or invalid path!")

    return Driver.options(name="Colossal-Inference-Driver").bind(config=args)
86 changes: 86 additions & 0 deletions examples/inference/serving/ray_serve/README.md
@@ -0,0 +1,86 @@
# Colossal-Inference with Ray Serve

This example demonstrates and tests the deployment of Colossal Inference from `colossalai.inference` with [Ray Serve](https://docs.ray.io/en/latest/serve/index.html). It imports inference modules from colossalai and is based on https://github.com/hpcaitech/ColossalAI/tree/a22706337a57dd1c98b95739dd09d98bd55947a0.

Both single-GPU inference and multi-GPU (i.e. tensor-parallel) inference serving are supported.

## Installation

### Conda Environment
```bash
# create a new conda env with python 3.8
conda create -n ray_test python=3.8.18

# use torch1.13+cuda11.6
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

# install ray from wheels
pip install -U "ray[default,serve]"

# install cuda toolkit (e.g. nvcc, etc)
conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit

# install cuDNN, cuTENSOR, and NCCL
conda install -c conda-forge cupy cudnn cutensor nccl cuda-version=11.6

# install colossalai with PyTorch extensions
cd <path_to_ColossalAI_repo>
CUDA_EXT=1 pip install -e .

# install other dependencies
pip install triton==2.0.0.dev20221202
pip install transformers
```
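
As an optional sanity check (an addition to the original instructions, not part of the example itself), you can verify that the CUDA-enabled torch build and the colossalai install are importable before launching Serve:

```python
# quick environment check; run inside the ray_test conda env created above
import torch
import colossalai  # noqa: F401  (import check only)

print(torch.__version__)
print(torch.cuda.is_available(), torch.cuda.device_count())
```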

## Launch Ray Serve and run the app
### Method #1. CLI command

From the current directory, we can launch the app with the following command:
```bash
RAY_DEDUP_LOGS=0 serve run Colossal_Inference_rayserve:app path="PATH_TO_YOUR_MODEL_DIR"
```

By default, Ray deduplicates logs across the cluster. Here we set `RAY_DEDUP_LOGS=0` to disable log deduplication, so that each actor logs its information to the CLI. `serve run` runs an application from the specified import path; the format is `<filename>:<app_name>`.

Then we can send requests by running the python script in another window:
```bash
python send_request.py
```
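
`send_request.py` simply wraps an HTTP GET in a Ray task. An equivalent minimal request, assuming Serve's default `localhost:8000` endpoint, can be sent without Ray at all, for example:

```python
import requests

# the Driver deployment reads the "text" query parameter
# (see __call__ in Colossal_Inference_rayserve.py)
resp = requests.get("http://localhost:8000/", params={"text": "Introduce some landmarks in Beijing"})
print(resp.text)
```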

### Method #2. Run inside script

We can also launch Ray Serve and run the app inside a single script with a few modifications:
To keep the Ray handle from raising an error when serializing pydantic objects, change the config class from `class GenConfigArgs(BaseModel)` to a dataclass:
```python
from dataclasses import dataclass

@dataclass
class GenConfigArgs:
    # attributes remain unchanged
```
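For reference, a complete version of the dataclass, keeping the same attributes and defaults as `GenConfigArgs` above, would look like this:
```python
from dataclasses import dataclass


@dataclass
class GenConfigArgs:
    path: str
    tp_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32
```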
Comment out the app builder
```python
# def app(args: GenConfigArgs) -> Application:
# ...
# return Driver.options(name="Colossal-Inference-Driver").bind(config=args)
```
Then append the following lines to the end of the file:
```python
from ray.serve.handle import DeploymentHandle, DeploymentResponse

app = Driver.bind(config=GenConfigArgs(path="<Path_to_model_dir>"))
handle: DeploymentHandle = serve.run(app).options(use_new_handle_api=True)
response: DeploymentResponse = handle.batch_generate.remote(requests="Introduce some landmarks in Beijing")
print(response.result())
```
Then we can run the script:
```bash
python Colossal_Inference_rayserve.py
```

### Terminate Ray Serve
With the second method, Ray Serve and the application terminate automatically when the script finishes. If you choose the first method (`serve run`), you can press `Ctrl+C` to shut down the application, or use `serve shutdown` to shut down Serve and delete all applications on the Ray cluster.

To make sure all the active Ray processes are killed, run
```bash
ray stop
```
15 changes: 15 additions & 0 deletions examples/inference/serving/ray_serve/send_request.py
@@ -0,0 +1,15 @@
import ray
import requests


@ray.remote
def send_query(text):
    resp = requests.get("http://localhost:8000/?text={}".format(text))
    return resp.text


test_sentence = "Introduce some landmarks in Beijing"

result = ray.get(send_query.remote(test_sentence))
print("Result returned:")
print(result)
27 changes: 27 additions & 0 deletions examples/inference/serving/ray_serve/send_requests.py
@@ -0,0 +1,27 @@
import ray
import requests


@ray.remote
def send_query(text):
    resp = requests.get("http://localhost:8000/?text={}".format(text))
    return resp.text


test_sentences = [
    "Introduce some landmarks in Beijing",
    "What is the weather today",
    "Coding requires practice and patience",
    "Rainy days inspire cozy reading",
    "Laughter is contagious and heartwarming",
    "Hiking mountains builds strength and resilience",
    "Family bonds grow stronger with time",
    "Science unlocks mysteries of the universe",
    "Music soothes the soul and ignites passion",
    "Artistic expression knows no boundaries",
]

results = ray.get([send_query.remote(text) for text in test_sentences])
print("Result returned:")
for res in results:
    print(res)
