Skip to content

Commit

Permalink
Merge pull request #558 from GooeyAI/support-o1
Browse files Browse the repository at this point in the history
Add support for o1-preview & o1-mini
  • Loading branch information
nikochiko authored Dec 12, 2024
2 parents 4589033 + afb4939 commit 1e0ba50
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 114 deletions.
44 changes: 36 additions & 8 deletions daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,35 @@ class LLMSpec(typing.NamedTuple):
is_vision_model: bool = False
is_deprecated: bool = False
supports_json: bool = False
supports_temperature: bool = True


class LargeLanguageModels(Enum):
# https://platform.openai.com/docs/models/gpt-4o
# https://platform.openai.com/docs/models#o1
# OpenAI o1-preview reasoning model.
# https://platform.openai.com/docs/models#o1
# NOTE(review): o1-series models reject the `temperature` parameter and JSON
# response format, hence supports_temperature=False / supports_json=False;
# the caller (run_language_model) nulls out temperature when unsupported.
o1_preview = LLMSpec(
    label="o1-preview (openai)",  # human-readable label shown to users
    model_id="o1-preview-2024-09-12",  # pinned dated snapshot sent to the API
    llm_api=LLMApis.openai,
    context_window=128_000,  # max tokens the model can attend to
    price=50,  # relative price unit used by this file's pricing logic — presumably internal credits; verify
    is_vision_model=False,
    supports_json=False,
    supports_temperature=False,
)

# https://platform.openai.com/docs/models#o1
# OpenAI o1-mini reasoning model (smaller/cheaper sibling of o1-preview).
# https://platform.openai.com/docs/models#o1
# NOTE(review): like o1-preview, this model does not accept `temperature` or
# JSON mode, so both capability flags are disabled here.
o1_mini = LLMSpec(
    label="o1-mini (openai)",  # human-readable label shown to users
    model_id="o1-mini-2024-09-12",  # pinned dated snapshot sent to the API
    llm_api=LLMApis.openai,
    context_window=128_000,  # max tokens the model can attend to
    price=13,  # relative price unit used by this file's pricing logic — presumably internal credits; verify
    is_vision_model=False,
    supports_json=False,
    supports_temperature=False,
)

# https://platform.openai.com/docs/models#gpt-4o
gpt_4_o = LLMSpec(
label="GPT-4o (openai)",
model_id="gpt-4o-2024-08-06",
Expand All @@ -84,7 +109,7 @@ class LargeLanguageModels(Enum):
is_vision_model=True,
supports_json=True,
)
# https://platform.openai.com/docs/models/gpt-4o-mini
# https://platform.openai.com/docs/models#gpt-4o-mini
gpt_4_o_mini = LLMSpec(
label="GPT-4o-mini (openai)",
model_id="gpt-4o-mini",
Expand Down Expand Up @@ -496,6 +521,7 @@ def __init__(self, *args):
self.is_chat_model = spec.is_chat_model
self.is_vision_model = spec.is_vision_model
self.supports_json = spec.supports_json
self.supports_temperature = spec.supports_temperature

@property
def value(self):
Expand Down Expand Up @@ -599,6 +625,8 @@ def run_language_model(
messages[0]["content"] = "\n\n".join(
[get_entry_text(messages[0]), DEFAULT_JSON_PROMPT]
)
if not model.supports_temperature:
temperature = None
result = _run_chat_model(
api=model.llm_api,
model=model.model_id,
Expand Down Expand Up @@ -735,7 +763,7 @@ def _run_chat_model(
messages: list[ConversationEntry],
max_tokens: int,
num_outputs: int,
temperature: float,
temperature: float | None,
stop: list[str] | None,
avoid_repetition: bool,
tools: list[LLMTool] | None,
Expand All @@ -750,7 +778,7 @@ def _run_chat_model(
return _run_openai_chat(
model=model,
avoid_repetition=avoid_repetition,
max_tokens=max_tokens,
max_completion_tokens=max_tokens,
messages=messages,
num_outputs=num_outputs,
stop=stop,
Expand Down Expand Up @@ -1003,9 +1031,9 @@ def _run_openai_chat(
*,
model: str,
messages: list[ConversationEntry],
max_tokens: int,
max_completion_tokens: int,
num_outputs: int,
temperature: float,
temperature: float | None = None,
stop: list[str] | None = None,
avoid_repetition: bool = False,
tools: list[LLMTool] | None = None,
Expand All @@ -1027,10 +1055,10 @@ def _run_openai_chat(
_get_chat_completions_create(
model=model_str,
messages=messages,
max_tokens=max_tokens,
max_completion_tokens=max_completion_tokens,
stop=stop or NOT_GIVEN,
n=num_outputs,
temperature=temperature,
temperature=temperature if temperature is not None else NOT_GIVEN,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
tools=[tool.spec for tool in tools] if tools else NOT_GIVEN,
Expand Down
Loading

0 comments on commit 1e0ba50

Please sign in to comment.