allow json mode on everything
add a default json mode prompt
devxpy committed Jul 22, 2024
1 parent 2cdd3e7 commit 8b8238b
Showing 13 changed files with 224 additions and 171 deletions.
16 changes: 9 additions & 7 deletions daras_ai_v2/doc_search_settings_widgets.py
@@ -10,7 +10,6 @@
 from daras_ai_v2.embedding_model import EmbeddingModels
 from daras_ai_v2.enum_selector_widget import enum_selector
 from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder
-from daras_ai_v2.prompt_vars import variables_input
 from daras_ai_v2.search_ref import CitationStyles

 _user_media_url_prefix = os.path.join(
@@ -171,8 +170,10 @@ def doc_search_advanced_settings():
     st.number_input(
         label="""
 ###### Max Snippet Words
-After a document search, relevant snippets of your documents are returned as results. This setting adjusts the maximum number of words in each snippet. A high snippet size allows the LLM to access more information from your document results, at the cost of being verbose and potentially exhausting input tokens (which can cause a failure of the copilot to respond). Default: 300
+After a document search, relevant snippets of your documents are returned as results.
+This setting adjusts the maximum number of words in each snippet (tokens = words * 2).
+A high snippet size allows the LLM to access more information from your document results, \
+at the cost of being verbose and potentially exhausting input tokens (which can cause a failure of the copilot to respond).
         """,
         key="max_context_words",
         min_value=10,
@@ -181,9 +182,10 @@

     st.number_input(
         label="""
-###### Overlapping Snippet Lines
-Your knowledge base documents are split into overlapping snippets. This settings adjusts how much those snippets overlap. In general you shouldn't need to adjust this. Default: 5
+###### Snippet Overlap Ratio
+Your knowledge base documents are split into overlapping snippets.
+This settings adjusts how much those snippets overlap (overlap tokens = snippet tokens / overlap ratio).
+In general you shouldn't need to adjust this.
         """,
         key="scroll_jump",
         min_value=1,
@@ -194,7 +196,7 @@
 def embeddings_model_selector(key: str):
     return enum_selector(
         EmbeddingModels,
-        label="##### Embeddings Model",
+        label="##### Embeddings Model",
         key=key,
         use_selectbox=True,
     )
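
Reviewer note: the two docstrings above encode the widget's token heuristics. A quick worked sketch of that arithmetic, using the removed defaults (300 max snippet words, overlap ratio 5) as assumed inputs:

# Worked example of the heuristics quoted in the docstrings above.
# The inputs are the old defaults from the removed text, not values
# computed by the widgets themselves.
max_context_words = 300  # "Max Snippet Words"
scroll_jump = 5          # "Snippet Overlap Ratio"

snippet_tokens = max_context_words * 2          # tokens = words * 2  -> 600
overlap_tokens = snippet_tokens // scroll_jump  # snippet tokens / overlap ratio -> 120
print(snippet_tokens, overlap_tokens)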
31 changes: 28 additions & 3 deletions daras_ai_v2/language_model.py
@@ -33,6 +33,10 @@
 from functions.recipe_functions import LLMTools

 DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible."
+DEFAULT_JSON_PROMPT = (
+    "Please respond directly in JSON format. "
+    "Don't output markdown or HTML, instead print the JSON object directly without formatting."
+)

 CHATML_ROLE_SYSTEM = "system"
 CHATML_ROLE_ASSISTANT = "assistant"
@@ -143,6 +147,7 @@ class LargeLanguageModels(Enum):
         llm_api=LLMApis.openai,
         context_window=4096,
         price=1,
+        supports_json=True,
     )
     gpt_3_5_turbo_16k = LLMSpec(
         label="ChatGPT 16k (openai)",
@@ -436,18 +441,34 @@ def run_language_model(

     model: LargeLanguageModels = LargeLanguageModels[str(model)]
     if model.is_chat_model:
-        if not messages:
+        if prompt and not messages:
             # convert text prompt to chat messages
             messages = [
-                {"role": "system", "content": DEFAULT_SYSTEM_MSG},
-                {"role": "user", "content": prompt},
+                format_chat_entry(role=CHATML_ROLE_SYSTEM, content=DEFAULT_SYSTEM_MSG),
+                format_chat_entry(role=CHATML_ROLE_USER, content=prompt),
             ]
         if not model.is_vision_model:
             # remove images from the messages
             messages = [
                 format_chat_entry(role=entry["role"], content=get_entry_text(entry))
                 for entry in messages
             ]
+        if (
+            messages
+            and response_format_type == "json_object"
+            and "JSON" not in str(messages).upper()
+        ):
+            if messages[0]["role"] != CHATML_ROLE_SYSTEM:
+                messages.insert(
+                    0,
+                    format_chat_entry(
+                        role=CHATML_ROLE_SYSTEM, content=DEFAULT_JSON_PROMPT
+                    ),
+                )
+            else:
+                messages[0]["content"] = "\n\n".join(
+                    [get_entry_text(messages[0]), DEFAULT_JSON_PROMPT]
+                )
         entries = _run_chat_model(
             api=model.llm_api,
             model=model.model_id,
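
Reviewer note: the added block above is the heart of the commit — when a caller asks for json_object output and no message already mentions JSON, a default JSON instruction is injected into (or prepended as) the system message. A standalone sketch of that behavior; format_chat_entry is simplified here and the sample messages are hypothetical:

# Standalone sketch of the injection logic above, outside the real codebase.
DEFAULT_JSON_PROMPT = (
    "Please respond directly in JSON format. "
    "Don't output markdown or HTML, instead print the JSON object directly without formatting."
)


def format_chat_entry(role: str, content: str) -> dict:
    return {"role": role, "content": content}


def inject_json_prompt(messages: list[dict]) -> list[dict]:
    # only act when nothing in the conversation mentions JSON yet
    if "JSON" in str(messages).upper():
        return messages
    if messages[0]["role"] != "system":
        messages.insert(0, format_chat_entry("system", DEFAULT_JSON_PROMPT))
    else:
        messages[0]["content"] = "\n\n".join(
            [messages[0]["content"], DEFAULT_JSON_PROMPT]
        )
    return messages


msgs = inject_json_prompt([format_chat_entry("user", "List three colors.")])
assert msgs[0]["role"] == "system" and "JSON" in msgs[0]["content"].upper()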
@@ -633,6 +654,7 @@ def _run_chat_model(
                 max_tokens=max_tokens,
                 temperature=temperature,
                 avoid_repetition=avoid_repetition,
+                response_format_type=response_format_type,
                 stop=stop,
             )
         case LLMApis.anthropic:
@@ -1030,6 +1052,7 @@ def _run_groq_chat(
     temperature: float,
     avoid_repetition: bool,
     stop: list[str] | None,
+    response_format_type: ResponseFormatType | None,
 ):
     from usage_costs.cost_utils import record_cost_auto
     from usage_costs.models import ModelSku
@@ -1045,6 +1068,8 @@ def _run_groq_chat(
         data["presence_penalty"] = 0.25
     if stop:
         data["stop"] = stop
+    if response_format_type:
+        data["response_format"] = {"type": response_format_type}
     r = requests.post(
         "https://api.groq.com/openai/v1/chat/completions",
         json=data,
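
Reviewer note: the two added lines forward the same OpenAI-style response_format object to Groq's OpenAI-compatible endpoint. A hedged sketch of the resulting request — the model name, message, and environment variable are placeholders, not values from the diff:

import os
import requests

# Illustrative request mirroring the payload built in _run_groq_chat above.
data = {
    "model": "llama3-70b-8192",  # placeholder Groq model id
    "messages": [{"role": "user", "content": 'Reply with {"ok": true}'}],
    "response_format": {"type": "json_object"},
}
r = requests.post(
    "https://api.groq.com/openai/v1/chat/completions",
    json=data,
    headers={"Authorization": f"Bearer {os.environ['GROQ_API_KEY']}"},
)
print(r.json()["choices"][0]["message"]["content"])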
110 changes: 66 additions & 44 deletions daras_ai_v2/language_model_settings_widgets.py
@@ -1,48 +1,57 @@
-import gooey_ui as st
+from pydantic import BaseModel, Field
+
+import gooey_ui as st
 from daras_ai_v2.enum_selector_widget import enum_selector, BLANK_OPTION
-from daras_ai_v2.language_model import LargeLanguageModels
+from daras_ai_v2.field_render import field_title_desc
+from daras_ai_v2.language_model import LargeLanguageModels, ResponseFormatType


-def language_model_settings(show_selector=True, show_response_format=True):
-    st.write("##### 🔠 Language Model Settings")
+class LanguageModelSettings(BaseModel):
+    avoid_repetition: bool | None
+    num_outputs: int | None
+    quality: float | None
+    max_tokens: int | None
+    sampling_temperature: float | None
+    response_format_type: ResponseFormatType = Field(
+        None,
+        title="Response Format",
+    )

-    selected_model = None
-    if show_selector:
-        enum_selector(
-            LargeLanguageModels,
-            label_visibility="collapsed",
-            key="selected_model",
-            use_selectbox=True,
-        )
-        selected_model = LargeLanguageModels[
-            st.session_state.get("selected_model") or LargeLanguageModels.gpt_4.name
-        ]
-
-    st.checkbox("Avoid Repetition", key="avoid_repetition")
+def language_model_selector(
+    label: str = "##### 🔠 Language Model Settings",
+    label_visibility: str = "visible",
+    key: str = "selected_model",
+):
+    return enum_selector(
+        LargeLanguageModels,
+        label=label,
+        label_visibility=label_visibility,
+        key=key,
+        use_selectbox=True,
+    )


+def language_model_settings(selected_model: str = None):
+    try:
+        llm = LargeLanguageModels[selected_model]
+    except KeyError:
+        llm = None

     col1, col2 = st.columns(2)
     with col1:
-        st.slider(
-            label="""
-###### Answer Outputs
-How many answers should the copilot generate? Additional answer outputs increase the cost of each run.
-            """,
-            key="num_outputs",
-            min_value=1,
-            max_value=4,
-        )
+        if selected_model and selected_model.is_chat_model:
+            st.checkbox("Avoid Repetition", key="avoid_repetition")

+    if not llm or llm.supports_json:
+        with col2:
-            st.slider(
-                label="""
-###### Attempts
-Generate multiple responses and choose the best one.
-                """,
-                key="quality",
-                min_value=1.0,
-                max_value=5.0,
-                step=0.1,
+            st.selectbox(
+                f"###### {field_title_desc(LanguageModelSettings, 'response_format_type')}",
+                options=[None, "json_object"],
+                key="response_format_type",
+                format_func={
+                    None: BLANK_OPTION,
+                    "json_object": "JSON Object",
+                }.__getitem__,
             )

     col1, col2 = st.columns(2)
@@ -68,13 +77,26 @@ def language_model_settings(show_selector=True, show_response_format=True):
             max_value=2.0,
         )

-    if show_response_format and (not selected_model or selected_model.supports_json):
-        st.selectbox(
-            f"###### Response Format",
-            options=[None, "json_object"],
-            key="response_format_type",
-            format_func={
-                None: BLANK_OPTION,
-                "json_object": "JSON Object",
-            }.__getitem__,
+    col1, col2 = st.columns(2)
+    with col1:
+        st.slider(
+            label="""
+###### Answer Outputs
+How many answers should the copilot generate? Additional answer outputs increase the cost of each run.
+            """,
+            key="num_outputs",
+            min_value=1,
+            max_value=4,
         )
+    if not llm or llm.is_chat_model:
+        with col2:
+            st.slider(
+                label="""
+###### Attempts
+Generate multiple responses and choose the best one
+                """,
+                key="quality",
+                min_value=1.0,
+                max_value=5.0,
+                step=0.1,
+            )
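
Reviewer note: with the selector split out of the settings function, pages now render the model dropdown and the per-model settings as two calls. A hypothetical caller, assuming language_model_selector returns the selected model's enum name from session state:

# Hypothetical page code; both functions are the ones defined in this diff.
selected_model = language_model_selector()
language_model_settings(selected_model)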
19 changes: 9 additions & 10 deletions recipes/BulkEval.py
@@ -22,6 +22,7 @@
     run_language_model,
     LargeLanguageModels,
 )
+from daras_ai_v2.language_model_settings_widgets import LanguageModelSettings
 from daras_ai_v2.prompt_vars import render_prompt_vars
 from recipes.BulkRunner import read_df_any, list_view_editor, del_button
 from recipes.DocSearch import render_documents
@@ -48,15 +49,6 @@
 ]


-class LLMSettingsMixin(BaseModel):
-    selected_model: typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
-    avoid_repetition: bool | None
-    num_outputs: int | None
-    quality: float | None
-    max_tokens: int | None
-    sampling_temperature: float | None
-
-
 class EvalPrompt(typing.TypedDict):
     name: str
     prompt: str
@@ -168,7 +160,7 @@ def related_workflows(self) -> list:

         return [BulkRunnerPage, VideoBotsPage, AsrPage, DocSearchPage]

-    class RequestModel(LLMSettingsMixin, BasePage.RequestModel):
+    class RequestModelBase(BasePage.RequestModel):
         documents: list[str] = Field(
             title="Input Data Spreadsheet",
             description="""
@@ -193,6 +185,13 @@ class RequestModel(LLMSettingsMixin, BasePage.RequestModel):
             """,
         )

+        selected_model: (
+            typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
+        )
+
+    class RequestModel(LanguageModelSettings, RequestModelBase):
+        pass
+
     class ResponseModel(BaseModel):
         output_documents: list[str]
         final_prompts: list[list[str]] | None
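
Reviewer note: the refactor swaps BulkEval's private LLMSettingsMixin for the shared LanguageModelSettings base, so the request model composes its fields by pydantic multiple inheritance. A self-contained sketch of that pattern, with simplified stand-in classes:

from pydantic import BaseModel

# Simplified stand-ins for LanguageModelSettings and BasePage.RequestModel.
class LanguageModelSettings(BaseModel):
    num_outputs: int | None = None
    response_format_type: str | None = None

class RequestModelBase(BaseModel):
    documents: list[str] = []
    selected_model: str | None = None

class RequestModel(LanguageModelSettings, RequestModelBase):
    pass

# The combined model validates fields from both bases in one pass.
req = RequestModel.parse_obj(
    {"documents": ["sheet.csv"], "response_format_type": "json_object"}
)
print(req.response_format_type)  # -> json_object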
21 changes: 8 additions & 13 deletions recipes/CompareLLM.py
@@ -14,7 +14,10 @@
     SUPERSCRIPT,
     ResponseFormatType,
 )
-from daras_ai_v2.language_model_settings_widgets import language_model_settings
+from daras_ai_v2.language_model_settings_widgets import (
+    language_model_settings,
+    LanguageModelSettings,
+)
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.prompt_vars import render_prompt_vars

@@ -39,22 +42,14 @@ class CompareLLMPage(BasePage):
         "sampling_temperature": 0.7,
     }

-    class RequestModel(BasePage.RequestModel):
+    class RequestModelBase(BasePage.RequestModel):
         input_prompt: str | None
         selected_models: (
             list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None
         )

-        avoid_repetition: bool | None
-        num_outputs: int | None
-        quality: float | None
-        max_tokens: int | None
-        sampling_temperature: float | None
-
-        response_format_type: ResponseFormatType = Field(
-            None,
-            title="Response Format",
-        )
+    class RequestModel(LanguageModelSettings, RequestModelBase):
+        pass

     class ResponseModel(BaseModel):
         output_text: dict[
@@ -95,7 +90,7 @@ def render_usage_guide(self):
         youtube_video("dhexRRDAuY8")

     def render_settings(self):
-        language_model_settings(show_selector=False)
+        language_model_settings()

     def run(self, state: dict) -> typing.Iterator[str | None]:
         request: CompareLLMPage.RequestModel = self.RequestModel.parse_obj(state)
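
Reviewer note: end to end, CompareLLM (and any recipe embedding LanguageModelSettings) can now forward response_format_type into run_language_model. A hypothetical call under the signature shown earlier in this diff — the model name and the return shape are assumptions, not values from the commit:

import json

# Assumes run_language_model returns a list of output strings.
out = run_language_model(
    model=LargeLanguageModels.gpt_3_5_turbo.name,  # hypothetical member name
    prompt="List three primary colors as a JSON array under the key 'colors'.",
    response_format_type="json_object",
)
print(json.loads(out[0])["colors"])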