Commit

compare llm support for streaming

devxpy committed Jan 31, 2024
1 parent d79ab39 commit 691bce5
Showing 2 changed files with 134 additions and 30 deletions.
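For context, a minimal sketch (not part of this commit) of how a caller might use the new stream flag on run_language_model, based on the signature and the CompareLLM changes in the diff below. The model key "gpt_4" and the prompt/parameter values are illustrative assumptions.

from daras_ai_v2.language_model import run_language_model

# with stream=True, run_language_model returns an iterable of partial output lists
# (for chat models, a generator that yields progressively longer outputs)
partial_outputs = run_language_model(
    model="gpt_4",  # assumed LargeLanguageModels key, for illustration only
    prompt="Write a haiku about streaming.",
    max_tokens=128,
    quality=1.0,
    num_outputs=1,
    temperature=0.7,
    stream=True,
)
for i, outputs in enumerate(partial_outputs):
    print(i, outputs[0])  # outputs[0] is the accumulated text of the first completion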
151 changes: 126 additions & 25 deletions daras_ai_v2/language_model.py
@@ -18,7 +18,11 @@
from django.conf import settings
from jinja2.lexer import whitespace_re
from loguru import logger
from openai.types.chat import ChatCompletionContentPartParam
from openai import Stream
from openai.types.chat import (
    ChatCompletionContentPartParam,
    ChatCompletionChunk,
)

from daras_ai_v2.asr import get_google_auth_session
from daras_ai_v2.exceptions import raise_for_status
@@ -27,7 +31,10 @@
from daras_ai_v2.redis_cache import (
    get_redis_cache,
)
from daras_ai_v2.text_splitter import default_length_function
from daras_ai_v2.text_splitter import (
    default_length_function,
    default_separators,
)

DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."

@@ -325,19 +332,23 @@ def run_language_model(
    stop: list[str] = None,
    avoid_repetition: bool = False,
    tools: list[LLMTools] = None,
    stream: bool = False,
    response_format_type: typing.Literal["text", "json_object"] = None,
) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]:
) -> (
    list[str]
    | tuple[list[str], list[list[dict]]]
    | typing.Generator[list[str], None, None]
):
    assert bool(prompt) != bool(
        messages
    ), "Please provide exactly one of { prompt, messages }"

    model: LargeLanguageModels = LargeLanguageModels[str(model)]
    api = llm_api[model]
    model_name = llm_model_names[model]
    is_chatml = False
    if model.is_chat_model():
        if messages:
            is_chatml = False
        else:
        if not messages:
            # if input is chatml, convert it into json messages
            is_chatml, messages = parse_chatml(prompt)  # type: ignore
        messages = messages or []
@@ -347,7 +358,7 @@ def run_language_model(
            format_chat_entry(role=entry["role"], content=get_entry_text(entry))
            for entry in messages
        ]
        result = _run_chat_model(
        entries = _run_chat_model(
            api=api,
            model=model_name,
            messages=messages,  # type: ignore
@@ -358,26 +369,18 @@ def run_language_model(
            avoid_repetition=avoid_repetition,
            tools=tools,
            response_format_type=response_format_type,
            # we can't stream with tools or json yet
            stream=stream and not (tools or response_format_type),
        )
        if response_format_type == "json_object":
            out_content = [json.loads(entry["content"]) for entry in result]
        if stream:
            return _stream_llm_outputs(entries, is_chatml, response_format_type, tools)
        else:
            out_content = [
                # return messages back as either chatml or json messages
                format_chatml_message(entry)
                if is_chatml
                else (entry.get("content") or "").strip()
                for entry in result
            ]
        if tools:
            return out_content, [(entry.get("tool_calls") or []) for entry in result]
        else:
            return out_content
            return _parse_entries(entries, is_chatml, response_format_type, tools)
    else:
        if tools:
            raise ValueError("Only OpenAI chat models support Tools")
        logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}")
        result = _run_text_model(
        msgs = _run_text_model(
            api=api,
            model=model_name,
            prompt=prompt,
@@ -388,7 +391,41 @@ def run_language_model(
            avoid_repetition=avoid_repetition,
            quality=quality,
        )
        return [msg.strip() for msg in result]
        ret = [msg.strip() for msg in msgs]
        if stream:
            ret = [ret]
        return ret


def _stream_llm_outputs(result, is_chatml, response_format_type, tools):
    if isinstance(result, list):  # compatibility with non-streaming apis
        result = [result]
    for entries in result:
        yield _parse_entries(entries, is_chatml, response_format_type, tools)


def _parse_entries(
    entries: list[dict],
    is_chatml: bool,
    response_format_type: typing.Literal["text", "json_object"] | None,
    tools: list[dict] | None,
):
    if response_format_type == "json_object":
        ret = [json.loads(entry["content"]) for entry in entries]
    else:
        ret = [
            # return messages back as either chatml or json messages
            (
                format_chatml_message(entry)
                if is_chatml
                else (entry.get("content") or "").strip()
            )
            for entry in entries
        ]
    if tools:
        return ret, [(entry.get("tool_calls") or []) for entry in entries]
    else:
        return ret


def _run_text_model(
@@ -439,7 +476,8 @@ def _run_chat_model(
    avoid_repetition: bool,
    tools: list[LLMTools] | None,
    response_format_type: typing.Literal["text", "json_object"] | None,
) -> list[ConversationEntry]:
    stream: bool = False,
) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]:
    match api:
        case LLMApis.openai:
            return _run_openai_chat(
@@ -452,6 +490,7 @@ def _run_chat_model(
                temperature=temperature,
                tools=tools,
                response_format_type=response_format_type,
                stream=stream,
            )
        case LLMApis.vertex_ai:
            if tools:
@@ -490,7 +529,8 @@ def _run_openai_chat(
    avoid_repetition: bool,
    tools: list[LLMTools] | None,
    response_format_type: typing.Literal["text", "json_object"] | None,
) -> list[ConversationEntry]:
    stream: bool = False,
) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]:
    from openai._types import NOT_GIVEN

    if avoid_repetition:
@@ -517,11 +557,72 @@ def _run_openai_chat(
                response_format={"type": response_format_type}
                if response_format_type
                else NOT_GIVEN,
                stream=stream,
            )
            for model_str in model
        ],
    )
    return [choice.message.dict() for choice in r.choices]
    if stream:
        return _stream_openai_chunked(r)
    else:
        return [choice.message.dict() for choice in r.choices]


def _stream_openai_chunked(
    r: Stream[ChatCompletionChunk],
    start_chunk_size: int = 50,
    stop_chunk_size: int = 400,
    step_chunk_size: int = 150,
):
    ret = []
    chunk_size = start_chunk_size

    for completion_chunk in r:
        changed = False
        for choice in completion_chunk.choices:
            try:
                entry = ret[choice.index]
            except IndexError:
                # initialize the entry
                entry = choice.delta.dict() | {"content": "", "chunk": ""}
                ret.append(entry)

            # append the delta to the current chunk
            if not choice.delta.content:
                continue
            entry["chunk"] += choice.delta.content
            # if the chunk is too small, we need to wait for more data
            chunk = entry["chunk"]
            if len(chunk) < chunk_size:
                continue

            # iterate through the separators and find the best one that matches
            for sep in default_separators[:-1]:
                # find the last occurrence of the separator
                match = None
                for match in re.finditer(sep, chunk):
                    pass
                if not match:
                    continue  # no match, try the next separator or wait for more data
                # append text before the separator to the content
                part = chunk[: match.end()]
                if len(part) < chunk_size:
                    continue  # not enough text, try the next separator or wait for more data
                entry["content"] += part
                # set text after the separator as the next chunk
                entry["chunk"] = chunk[match.end() :]
                # increase the chunk size, but don't go over the max
                chunk_size = min(chunk_size + step_chunk_size, stop_chunk_size)
                # we found a separator, so we can stop looking and yield the partial result
                changed = True
                break
        if changed:
            yield ret

    # add the leftover chunks
    for entry in ret:
        entry["content"] += entry["chunk"]
    yield ret


@retry_if(openai_should_retry)
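For reference, a standalone sketch (not part of this commit) of the flush-on-separator buffering that _stream_openai_chunked implements above: deltas are accumulated per output and only flushed once the buffer exceeds a growing chunk_size and can be cut at a separator. The fake_deltas input and the separators list are illustrative assumptions; the real code uses default_separators from daras_ai_v2.text_splitter, whose values are not shown in this diff.

import re

fake_deltas = ["Streaming lets the UI ", "show partial text. ", "It flushes on ",
               "sentence boundaries, ", "then grows the chunk size."]
separators = [r"\n\n", r"\n", r"\. ", r" "]  # assumed stand-in for default_separators[:-1]

buffer, content, chunk_size = "", "", 20
for delta in fake_deltas:
    buffer += delta
    if len(buffer) < chunk_size:
        continue  # too little text buffered, wait for more deltas
    for sep in separators:
        match = None
        for match in re.finditer(sep, buffer):
            pass  # keep only the last occurrence of this separator
        if not match:
            continue  # try a weaker separator
        # flush everything up to the separator, keep the rest buffered
        content, buffer = content + buffer[: match.end()], buffer[match.end():]
        chunk_size = min(chunk_size + 20, 80)  # grow toward a cap, like step/stop_chunk_size
        print(repr(content))  # the partial result a streaming caller would see
        break
content += buffer  # flush whatever is left at the end of the stream
print(repr(content))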
13 changes: 8 additions & 5 deletions recipes/CompareLLM.py
@@ -1,10 +1,9 @@
import random
import typing


import gooey_ui as st
from pydantic import BaseModel

import gooey_ui as st
from bots.models import Workflow
from daras_ai_v2.base import BasePage
from daras_ai_v2.enum_selector_widget import enum_multiselect
@@ -94,17 +93,21 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
        state["output_text"] = output_text = {}

        for selected_model in request.selected_models:
            yield f"Running {LargeLanguageModels[selected_model].value}..."

            output_text[selected_model] = run_language_model(
            model = LargeLanguageModels[selected_model]
            yield f"Running {model.value}..."
            ret = run_language_model(
                model=selected_model,
                quality=request.quality,
                num_outputs=request.num_outputs,
                temperature=request.sampling_temperature,
                prompt=prompt,
                max_tokens=request.max_tokens,
                avoid_repetition=request.avoid_repetition,
                stream=True,
            )
            for i, item in enumerate(ret):
                output_text[selected_model] = item
                yield f"Streaming {model.value}... {i + 1}"

    def render_output(self):
        self._render_outputs(st.session_state, 450)
