Serve JSON with vLLM using FastAPI and gunicorn
Showing 4 changed files with 342 additions and 1 deletion.
@@ -0,0 +1,122 @@
# Copyright 2023 the vLLM developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from typing import AsyncGenerator

import uvicorn
import vllm
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from .vllm import JSONLogitsProcessor, PatchedSampler

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
app = FastAPI()
engine = None

# Patch the sampler so it is compatible with `JSONLogitsProcessor`
vllm.model_executor.layers.sampler.Sampler = PatchedSampler


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate a completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - schema: the JSON schema to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (see `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)

    json_schema = request_dict.pop("schema", None)
    if json_schema is not None:
        logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]  # type: ignore
    else:
        logits_processors = []

    sampling_params = SamplingParams(
        **request_dict, logits_processors=logits_processors
    )
    request_id = random_uuid()

    results_generator = engine.generate(prompt, sampling_params, request_id)  # type: ignore

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [prompt + output.text for output in request_output.outputs]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)  # type: ignore
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    # Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len`
    # arguments
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Sets the default for the model (`facebook/opt-125m`)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )
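For reference, the /generate endpoint above can be exercised with a plain HTTP client. The sketch below is illustrative only: the host, port, prompt and schema are assumptions, not part of the commit. It sends a prompt together with a JSON schema and a couple of sampling parameters; everything other than prompt, schema and stream is forwarded to `SamplingParams`.

# Illustrative client for the /generate endpoint (host, prompt and schema are
# assumptions, not part of this commit).
import requests

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Describe a fictional person as JSON: ",
        "schema": schema,          # triggers JSON-guided generation
        "max_tokens": 100,         # forwarded to SamplingParams
        "temperature": 0.0,        # forwarded to SamplingParams
    },
)
print(response.json()["text"])

When "stream" is set to true the server instead returns NUL-delimited JSON chunks, so a streaming client would read the response body incrementally and split on "\0".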
@@ -0,0 +1,215 @@
"""Make vLLM compatible with Outlines' guided generation.""" | ||
import json | ||
import math | ||
from collections import defaultdict | ||
from typing import DefaultDict, List, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
from vllm.model_executor.layers.sampler import ( | ||
_SAMPLING_EPS, | ||
_apply_min_p, | ||
_apply_penalties, | ||
_apply_top_p_top_k, | ||
_build_sampler_output, | ||
_get_logits, | ||
_get_logprobs, | ||
_get_penalties, | ||
_get_temperatures, | ||
_get_top_p_top_k_min_p, | ||
_prune_hidden_states, | ||
_sample, | ||
) | ||
|
||
from outlines.fsm.fsm import RegexFSM | ||
from outlines.fsm.json_schema import build_regex_from_object | ||
|
||
|
||
def _patched_apply_logits_processors( | ||
logits, | ||
sampling_metadata, | ||
): | ||
"""Patch vLLM's logit processor. | ||
We need to patch the logits processor to pass the `seq_id` so we can | ||
handle several sequences in `JSONLogitsProcessor` | ||
""" | ||
logits_row_idx = 0 | ||
found_logits_processors = False | ||
for seq_ids, sampling_params in sampling_metadata.seq_groups: | ||
logits_processors = sampling_params.logits_processors | ||
if logits_processors: | ||
found_logits_processors = True | ||
for seq_id in seq_ids: | ||
logits_row = logits[logits_row_idx] | ||
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids | ||
for logits_processor in logits_processors: | ||
logits_row = logits_processor(seq_id, token_ids, logits_row) | ||
logits[logits_row_idx] = logits_row | ||
logits_row_idx += 1 | ||
else: | ||
logits_row_idx += len(seq_ids) | ||
if found_logits_processors: | ||
assert logits_row_idx == logits.shape[0] | ||
return logits | ||
|
||
|
||
class PatchedSampler(nn.Module): | ||
"""This code is copied from vLLM and uses the patched logits processor. | ||
Samples the next tokens from the model's outputs. | ||
This layer does the following: | ||
1. Discard the hidden states that are not used for sampling (i.e., all | ||
tokens except the final one in each prompt). | ||
2. Compute the logits for the next tokens. | ||
3. Apply presence, frequency and repetition penalties. | ||
4. Apply temperature scaling. | ||
5. Apply top-p and top-k truncation. | ||
6. Sample the next tokens. | ||
Here, each sequence group within the batch can have different sampling | ||
parameters (e.g., sampling method, temperature, top-p, top-k, etc.). | ||
""" | ||
|
||
def __init__(self, vocab_size: int) -> None: | ||
super().__init__() | ||
self.vocab_size = vocab_size | ||
|
||
def forward( | ||
self, | ||
embedding: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
sampling_metadata, | ||
embedding_bias: Optional[torch.Tensor] = None, | ||
): | ||
# Get the hidden states that we use for sampling. | ||
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) | ||
|
||
# Get the logits for the next tokens. | ||
logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) | ||
|
||
# Apply logits processors (if any). | ||
logits = _patched_apply_logits_processors(logits, sampling_metadata) | ||
# Apply presence and frequency penalties. | ||
presence_penalties, frequency_penalties, repetition_penalties = _get_penalties( | ||
sampling_metadata | ||
) | ||
assert len(presence_penalties) == logits.shape[0] | ||
assert len(frequency_penalties) == logits.shape[0] | ||
assert len(repetition_penalties) == logits.shape[0] | ||
logits = _apply_penalties( | ||
logits, | ||
sampling_metadata, | ||
presence_penalties, | ||
frequency_penalties, | ||
repetition_penalties, | ||
) | ||
|
||
# Apply temperature scaling. | ||
temperatures = _get_temperatures(sampling_metadata) | ||
assert len(temperatures) == logits.shape[0] | ||
if any(t != 1.0 for t in temperatures): | ||
t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device) | ||
# Use in-place division to avoid creating a new tensor. | ||
logits.div_(t.unsqueeze(dim=1)) | ||
|
||
# Apply top-p and top-k truncation. | ||
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p( | ||
sampling_metadata, self.vocab_size | ||
) | ||
assert len(top_ps) == len(top_ks) == logits.shape[0] | ||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) | ||
do_top_k = any(k != self.vocab_size for k in top_ks) | ||
if do_top_p or do_top_k: | ||
logits = _apply_top_p_top_k(logits, top_ps, top_ks) | ||
|
||
do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps) | ||
if do_min_p: | ||
logits = _apply_min_p(logits, min_ps) | ||
|
||
# We use float32 for probabilities and log probabilities. | ||
# Compute the probabilities. | ||
probs = torch.softmax(logits, dim=-1, dtype=torch.float) | ||
# Compute the log probabilities. | ||
# Use log_softmax to ensure numerical stability. | ||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) | ||
|
||
# Sample the next tokens. | ||
sample_results = _sample(probs, logprobs, sampling_metadata) | ||
# Get the logprobs query results. | ||
prompt_logprobs, sample_logprobs = _get_logprobs( | ||
logprobs, sampling_metadata, sample_results | ||
) | ||
return _build_sampler_output( | ||
sample_results, sampling_metadata, prompt_logprobs, sample_logprobs | ||
) | ||
|
||
|
||
class JSONLogitsProcessor: | ||
def __init__(self, schema, llm): | ||
"""Compile the FSM that drives the JSON-guided generation. | ||
Parameters | ||
---------- | ||
pydantic_model | ||
A Pydantic `BaseModel` that encodes the structure we want | ||
the model to generate. | ||
llm | ||
An instance of `vllm.LLM` | ||
""" | ||
if isinstance(schema, dict): | ||
schema = json.dumps(schema) | ||
regex_str = build_regex_from_object(schema) | ||
tokenizer = self.adapt_tokenizer(llm.tokenizer) | ||
|
||
fsm = RegexFSM(regex_str, tokenizer) | ||
self.fsm = fsm | ||
|
||
def __call__( | ||
self, seq_id: int, input_ids: List[int], scores: torch.Tensor | ||
) -> torch.Tensor: | ||
"""Use the FSM to bias the logits before sampling the next token.""" | ||
|
||
if len(input_ids) == 0: # Initialize the fsm states | ||
self.fsm_state: DefaultDict[int, int] = defaultdict(int) | ||
else: | ||
last_token = input_ids[-1] | ||
self.fsm_state[seq_id] = self.fsm.next_state( | ||
self.fsm_state[seq_id], last_token | ||
) | ||
|
||
allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id]) | ||
|
||
mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) | ||
mask[allowed_tokens] = 0 | ||
biased_scores = scores + mask | ||
|
||
return biased_scores | ||
|
||
def adapt_tokenizer(self, tokenizer): | ||
"""Adapt vLLM's tokenizer to use to compile the FSM. | ||
The API of Outlines tokenizers is slightly different to that of | ||
`transformers`. In addition we need to handle the missing spaces to | ||
Llama's tokenizer to be able to compile FSMs for this model. | ||
""" | ||
tokenizer.vocabulary = tokenizer.get_vocab() | ||
tokenizer.special_tokens = set(tokenizer.all_special_tokens) | ||
|
||
def convert_token_to_string(token: str) -> str: | ||
from transformers.file_utils import SPIECE_UNDERLINE | ||
|
||
string = tokenizer.convert_tokens_to_string([token]) | ||
|
||
# A hack to handle missing spaces to HF's Llama tokenizers | ||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": | ||
return " " + string | ||
|
||
return string | ||
|
||
tokenizer.convert_token_to_string = convert_token_to_string | ||
|
||
return tokenizer |
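To make the masking step in `JSONLogitsProcessor.__call__` concrete, here is a small standalone sketch; the vocabulary size and allowed token ids are made up for illustration and are not part of the commit. Tokens outside the set allowed by the FSM receive a score of -inf, so softmax assigns them zero probability and the sampler can never emit them.

# Toy illustration (not part of the commit) of the FSM masking step.
import math
import torch

vocab_size = 8
scores = torch.randn(vocab_size)   # raw logits for one sequence
allowed_tokens = [1, 4, 5]         # hypothetical ids allowed by the FSM state

# Everything outside the allowed set gets -inf, exactly as in __call__ above.
mask = torch.full((vocab_size,), -math.inf)
mask[allowed_tokens] = 0.0
biased_scores = scores + mask

print(biased_scores)  # only indices 1, 4 and 5 remain finite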