Skip to content

Commit

Permalink
Serve JSON with vLLM using FastAPI and gunicorn
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 21, 2023
1 parent 67e524d commit 8eb7ac0
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 1 deletion.
Empty file added outlines/serve/__init__.py
Empty file.
122 changes: 122 additions & 0 deletions outlines/serve/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2023 the vLLM developers

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from typing import AsyncGenerator

import uvicorn
import vllm
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from .vllm import JSONLogitsProcessor, PatchedSampler

TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
engine = None

# Patch the sampler so it is compatible with `JSONLogitsProcessor`
vllm.model_executor.layers.sampler.Sampler = PatchedSampler


@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- schema: the JSON schema to use for the generation
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)

json_schema = request_dict.pop("schema", None)
if json_schema is not None:
logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] # type: ignore
else:
logits_processors = []

sampling_params = SamplingParams(
**request_dict, logits_processors=logits_processors
)
request_id = random_uuid()

results_generator = engine.generate(prompt, sampling_params, request_id) # type: ignore

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [prompt + output.text for output in request_output.outputs]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")

if stream:
return StreamingResponse(stream_results())

# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id) # type: ignore
return Response(status_code=499)
final_output = request_output

assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

# Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len`
# arguments
engine_args = AsyncEngineArgs.from_cli_args(args)

# Sets default for the model (`facebook/opt-125m`)
engine = AsyncLLMEngine.from_engine_args(engine_args)

uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
215 changes: 215 additions & 0 deletions outlines/serve/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
"""Make vLLM compatible with Outlines' guided generation."""
import json
import math
from collections import defaultdict
from typing import DefaultDict, List, Optional

import torch
import torch.nn as nn
from vllm.model_executor.layers.sampler import (
_SAMPLING_EPS,
_apply_min_p,
_apply_penalties,
_apply_top_p_top_k,
_build_sampler_output,
_get_logits,
_get_logprobs,
_get_penalties,
_get_temperatures,
_get_top_p_top_k_min_p,
_prune_hidden_states,
_sample,
)

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


def _patched_apply_logits_processors(
logits,
sampling_metadata,
):
"""Patch vLLM's logit processor.
We need to patch the logits processor to pass the `seq_id` so we can
handle several sequences in `JSONLogitsProcessor`
"""
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(seq_id, token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits


class PatchedSampler(nn.Module):
"""This code is copied from vLLM and uses the patched logits processor.
Samples the next tokens from the model's outputs.
This layer does the following:
1. Discard the hidden states that are not used for sampling (i.e., all
tokens except the final one in each prompt).
2. Compute the logits for the next tokens.
3. Apply presence, frequency and repetition penalties.
4. Apply temperature scaling.
5. Apply top-p and top-k truncation.
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""

def __init__(self, vocab_size: int) -> None:
super().__init__()
self.vocab_size = vocab_size

def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata,
embedding_bias: Optional[torch.Tensor] = None,
):
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)

# Apply logits processors (if any).
logits = _patched_apply_logits_processors(logits, sampling_metadata)
# Apply presence and frequency penalties.
presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
sampling_metadata
)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(
logits,
sampling_metadata,
presence_penalties,
frequency_penalties,
repetition_penalties,
)

# Apply temperature scaling.
temperatures = _get_temperatures(sampling_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))

# Apply top-p and top-k truncation.
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
sampling_metadata, self.vocab_size
)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
logits = _apply_top_p_top_k(logits, top_ps, top_ks)

do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
if do_min_p:
logits = _apply_min_p(logits, min_ps)

# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata)
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results
)
return _build_sampler_output(
sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
)


class JSONLogitsProcessor:
def __init__(self, schema, llm):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
----------
pydantic_model
A Pydantic `BaseModel` that encodes the structure we want
the model to generate.
llm
An instance of `vllm.LLM`
"""
if isinstance(schema, dict):
schema = json.dumps(schema)
regex_str = build_regex_from_object(schema)
tokenizer = self.adapt_tokenizer(llm.tokenizer)

fsm = RegexFSM(regex_str, tokenizer)
self.fsm = fsm

def __call__(
self, seq_id: int, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""

if len(input_ids) == 0: # Initialize the fsm states
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
else:
last_token = input_ids[-1]
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[seq_id], last_token
)

allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
mask[allowed_tokens] = 0
biased_scores = scores + mask

return biased_scores

def adapt_tokenizer(self, tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

tokenizer.convert_token_to_string = convert_token_to_string

return tokenizer
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ test = [
"datasets",
"responses",
]
serve = ["vllm"]

[project.urls]
homepage = "https://github.com/outlines-dev/outlines"
Expand Down Expand Up @@ -102,14 +103,17 @@ module = [
"referencing.*",
"scipy.*",
"tiktoken.*",
"torch",
"torch.*",
"transformers.*",
"lark.*",
"interegular.*",
"datasets.*",
"numba.*",
"requests.*",
"responses.*",
"vllm.*",
"uvicorn.*",
"fastapi.*",
]
ignore_missing_imports = true

Expand Down

0 comments on commit 8eb7ac0

Please sign in to comment.