Skip to content

Commit

Permalink
added json format to doc_search, doc_summary, and google_gpt
Browse files Browse the repository at this point in the history
  • Loading branch information
SanderGi committed Jul 15, 2024
1 parent eb6d1c3 commit d9a05e7
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 32 deletions.
5 changes: 5 additions & 0 deletions daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class LLMSpec(typing.NamedTuple):
is_chat_model: bool = True
is_vision_model: bool = False
is_deprecated: bool = False
supports_json: bool = False


class LargeLanguageModels(Enum):
Expand All @@ -75,6 +76,7 @@ class LargeLanguageModels(Enum):
context_window=128_000,
price=10,
is_vision_model=True,
supports_json=True,
)
# https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4
gpt_4_turbo_vision = LLMSpec(
Expand All @@ -87,6 +89,7 @@ class LargeLanguageModels(Enum):
context_window=128_000,
price=6,
is_vision_model=True,
supports_json=True,
)
gpt_4_vision = LLMSpec(
label="GPT-4 Vision (openai) 🔻",
Expand All @@ -104,6 +107,7 @@ class LargeLanguageModels(Enum):
llm_api=LLMApis.openai,
context_window=128_000,
price=5,
supports_json=True,
)

# https://platform.openai.com/docs/models/gpt-4
Expand Down Expand Up @@ -327,6 +331,7 @@ def __init__(self, *args):
self.is_deprecated = spec.is_deprecated
self.is_chat_model = spec.is_chat_model
self.is_vision_model = spec.is_vision_model
self.supports_json = spec.supports_json

@property
def value(self):
Expand Down
26 changes: 18 additions & 8 deletions daras_ai_v2/language_model_settings_widgets.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import gooey_ui as st

from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.enum_selector_widget import enum_selector, BLANK_OPTION
from daras_ai_v2.language_model import LargeLanguageModels


def language_model_settings(show_selector=True):
def language_model_settings(show_selector=True, show_response_format=True):
st.write("##### 🔠 Language Model Settings")

selected_model = None
if show_selector:
enum_selector(
LargeLanguageModels,
label_visibility="collapsed",
key="selected_model",
use_selectbox=True,
)
selected_model = LargeLanguageModels[
st.session_state.get("selected_model") or LargeLanguageModels.gpt_4.name
]

st.checkbox("Avoid Repetition", key="avoid_repetition")

Expand All @@ -28,12 +32,7 @@ def language_model_settings(show_selector=True):
min_value=1,
max_value=4,
)
if (
show_selector
and not LargeLanguageModels[
st.session_state.get("selected_model") or LargeLanguageModels.gpt_4.name
].is_chat_model
):
if selected_model and selected_model.is_chat_model:
with col2:
st.slider(
label="""
Expand Down Expand Up @@ -68,3 +67,14 @@ def language_model_settings(show_selector=True):
min_value=0.0,
max_value=2.0,
)

if show_response_format and (not selected_model or selected_model.supports_json):
st.selectbox(
f"###### Response Format",
options=[None, "json_object"],
key="response_format_type",
format_func={
None: BLANK_OPTION,
"json_object": "JSON Object",
}.__getitem__,
)
14 changes: 2 additions & 12 deletions recipes/CompareLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import gooey_ui as st
from bots.models import Workflow
from daras_ai_v2.base import BasePage
from daras_ai_v2.enum_selector_widget import enum_multiselect, BLANK_OPTION
from daras_ai_v2.field_render import field_title
from daras_ai_v2.enum_selector_widget import enum_multiselect
from daras_ai_v2.language_model import (
run_language_model,
LargeLanguageModels,
Expand All @@ -17,7 +16,7 @@
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.prompt_vars import variables_input, render_prompt_vars
from daras_ai_v2.prompt_vars import render_prompt_vars

DEFAULT_COMPARE_LM_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/fef06d86-1f70-11ef-b8ee-02420a00015b/LLMs.jpg"

Expand Down Expand Up @@ -95,15 +94,6 @@ def render_usage_guide(self):

def render_settings(self):
language_model_settings(show_selector=False)
st.selectbox(
f"###### {field_title(self.RequestModel, 'response_format_type')}",
options=[None, "json_object"],
key="response_format_type",
format_func={
None: BLANK_OPTION,
"json_object": "JSON Object",
}.__getitem__,
)

def run(self, state: dict) -> typing.Iterator[str | None]:
request: CompareLLMPage.RequestModel = self.RequestModel.parse_obj(state)
Expand Down
7 changes: 5 additions & 2 deletions recipes/DocExtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
flatapply_parallel,
)
from daras_ai_v2.gdrive_downloader import is_gdrive_url, gdrive_download
from daras_ai_v2.language_model import run_language_model, LargeLanguageModels
from daras_ai_v2.language_model import (
run_language_model,
LargeLanguageModels,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.settings import service_account_key_path
Expand Down Expand Up @@ -138,7 +141,7 @@ def render_settings(self):
key="task_instructions",
height=300,
)
language_model_settings()
language_model_settings(show_response_format=False)

enum_selector(AsrModels, label="##### ASR Model", key="selected_asr_model")
st.write("---")
Expand Down
11 changes: 9 additions & 2 deletions recipes/DocSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing

from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_ui as st
from bots.models import Workflow
Expand All @@ -19,10 +19,11 @@
from daras_ai_v2.language_model import (
run_language_model,
LargeLanguageModels,
ResponseFormatType,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.prompt_vars import variables_input, render_prompt_vars
from daras_ai_v2.prompt_vars import render_prompt_vars
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.search_ref import (
SearchReference,
Expand Down Expand Up @@ -76,6 +77,11 @@ class RequestModel(DocSearchRequest, BasePage.RequestModel):
max_tokens: int | None
sampling_temperature: float | None

response_format_type: ResponseFormatType = Field(
None,
title="Response Format",
)

citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None

class ResponseModel(BaseModel):
Expand Down Expand Up @@ -202,6 +208,7 @@ def run_v2(
prompt=response.final_prompt,
max_tokens=request.max_tokens,
avoid_repetition=request.avoid_repetition,
response_format_type=request.response_format_type,
)

citation_style = (
Expand Down
9 changes: 8 additions & 1 deletion recipes/DocSummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum

from daras_ai_v2.pydantic_validation import FieldHttpUrl
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_ui as st
from bots.models import Workflow
Expand All @@ -16,6 +16,7 @@
LargeLanguageModels,
run_language_model,
calc_gpt_tokens,
ResponseFormatType,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.pt import PromptTree
Expand Down Expand Up @@ -72,6 +73,11 @@ class RequestModel(BasePage.RequestModel):
max_tokens: int | None
sampling_temperature: float | None

response_format_type: ResponseFormatType = Field(
None,
title="Response Format",
)

chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None

selected_asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None
Expand Down Expand Up @@ -240,6 +246,7 @@ def llm(p: str) -> str:
num_outputs=request.num_outputs,
temperature=request.sampling_temperature,
avoid_repetition=request.avoid_repetition,
response_format_type=request.response_format_type,
)[0]

state["prompt_tree"] = prompt_tree = []
Expand Down
11 changes: 9 additions & 2 deletions recipes/GoogleGPT.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_ui as st
from bots.models import Workflow
Expand All @@ -14,10 +14,11 @@
from daras_ai_v2.language_model import (
run_language_model,
LargeLanguageModels,
ResponseFormatType,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.prompt_vars import render_prompt_vars, variables_input
from daras_ai_v2.prompt_vars import render_prompt_vars
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.search_ref import (
SearchReference,
Expand Down Expand Up @@ -89,6 +90,11 @@ class RequestModel(GoogleSearchMixin, BasePage.RequestModel):
max_tokens: int | None
sampling_temperature: float | None

response_format_type: ResponseFormatType = Field(
None,
title="Response Format",
)

max_search_urls: int | None

max_references: int | None
Expand Down Expand Up @@ -279,4 +285,5 @@ def run_v2(
prompt=response.final_prompt,
max_tokens=request.max_tokens,
avoid_repetition=request.avoid_repetition,
response_format_type=request.response_format_type,
)
2 changes: 1 addition & 1 deletion recipes/SEOSummary.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def render_settings(self):
st.checkbox("Enable Internal Cross-Linking", key="enable_crosslinks")
st.checkbox("Enable HTML Formatting", key="enable_html")

language_model_settings()
language_model_settings(show_response_format=False)

st.write("---")

Expand Down
4 changes: 2 additions & 2 deletions recipes/SmartGPT.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

import jinja2.sandbox
from pydantic import BaseModel
from pydantic import BaseModel, Field

import gooey_ui as st
from bots.models import Workflow
Expand Down Expand Up @@ -80,7 +80,7 @@ def render_settings(self):
""",
key="dera_prompt",
)
language_model_settings()
language_model_settings(show_response_format=False)

def related_workflows(self):
from recipes.CompareLLM import CompareLLMPage
Expand Down
2 changes: 1 addition & 1 deletion recipes/SocialLookupEmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def render_form_v2(self):
)

def render_settings(self):
language_model_settings()
language_model_settings(show_response_format=False)

# st.text_input("URL 1", key="url1")
# st.text_input("URL 2", key="url2")
Expand Down
2 changes: 1 addition & 1 deletion recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def render_settings(self):
doc_search_advanced_settings()
st.write("---")

language_model_settings(show_selector=False)
language_model_settings(show_selector=False, show_response_format=False)

st.write("---")

Expand Down

0 comments on commit d9a05e7

Please sign in to comment.