Skip to content

Commit

Permalink
Reraise Jinja2 TemplateError (#840)
Browse files Browse the repository at this point in the history
* Handle jinja2.TemplateError in gateway

* Raise GatewayError in openai interface
  • Loading branch information
Egor-S authored Jan 18, 2024
1 parent 3121673 commit afa3ccc
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
27 changes: 22 additions & 5 deletions gateway/src/dstack/gateway/openai/clients/tgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import httpx
import jinja2
import jinja2.sandbox

from dstack.gateway.errors import GatewayError
from dstack.gateway.openai.clients import ChatCompletionsClient
Expand All @@ -31,9 +32,17 @@ def __init__(
headers={} if host is None else {"Host": host},
timeout=60,
)
self.chat_template = jinja2.Template(chat_template)
self.eos_token = eos_token

try:
jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
)
jinja_env.globals["raise_exception"] = raise_exception
self.chat_template = jinja_env.from_string(chat_template)
except jinja2.TemplateError as e:
raise GatewayError(f"Failed to compile chat template: {e}")

async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
payload = self.get_payload(request)
resp = await self.client.post("/generate", json=payload)
Expand Down Expand Up @@ -123,10 +132,14 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
yield chunk

def get_payload(self, request: ChatCompletionsRequest) -> Dict:
inputs = self.chat_template.render(
messages=request.messages,
add_generation_prompt=True,
)
try:
inputs = self.chat_template.render(
messages=request.messages,
add_generation_prompt=True,
)
except jinja2.TemplateError as e:
raise GatewayError(f"Failed to render chat template: {e}")

stop = ([request.stop] if isinstance(request.stop, str) else request.stop) or []
if self.eos_token not in stop:
stop.append(self.eos_token)
Expand Down Expand Up @@ -178,3 +191,7 @@ def __del__(self):
asyncio.get_running_loop().create_task(self.aclose())
except Exception:
pass


def raise_exception(message: str):
    """Raise a ``jinja2.TemplateError`` carrying *message*.

    This function is registered as the ``raise_exception`` global of the
    sandboxed chat-template environment, so a chat template can call it to
    reject input it cannot handle. The resulting ``TemplateError`` is caught
    where the template is rendered and re-raised as a ``GatewayError``.

    Args:
        message: Human-readable description of the problem.

    Raises:
        jinja2.TemplateError: Always.
    """
    error = jinja2.TemplateError(message)
    raise error
20 changes: 12 additions & 8 deletions gateway/src/dstack/gateway/openai/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse

from dstack.gateway.errors import GatewayError
from dstack.gateway.openai.schemas import (
ChatCompletionsChunk,
ChatCompletionsRequest,
Expand All @@ -25,14 +26,17 @@ async def get_models(
async def post_chat_completions(
project: str, body: ChatCompletionsRequest, store: Annotated[OpenAIStore, Depends(get_store)]
):
client = await store.get_chat_client(project, body.model)
if not body.stream:
return await client.generate(body)
else:
return StreamingResponse(
stream_chunks(client.stream(body)),
media_type="text/event-stream",
)
try:
client = await store.get_chat_client(project, body.model)
if not body.stream:
return await client.generate(body)
else:
return StreamingResponse(
stream_chunks(client.stream(body)),
media_type="text/event-stream",
)
except GatewayError as e:
raise e.http()


async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
Expand Down

0 comments on commit afa3ccc

Please sign in to comment.