diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 7b3bcd218..b4e882808 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -581,8 +581,7 @@ def run_language_model( if not model.is_vision_model: # remove images from the messages messages = [ - format_chat_entry(role=entry["role"], content=get_entry_text(entry)) - for entry in messages + entry | dict(content=get_entry_text(entry)) for entry in messages ] if ( messages @@ -781,12 +780,11 @@ def _run_chat_model( temperature=temperature, ) case LLMApis.groq: - if tools: - raise ValueError("Only OpenAI chat models support Tools") return _run_groq_chat( model=model, messages=messages, max_tokens=max_tokens, + tools=tools, temperature=temperature, avoid_repetition=avoid_repetition, response_format_type=response_format_type, @@ -1250,6 +1248,7 @@ def _run_groq_chat( model: str, messages: list[ConversationEntry], max_tokens: int, + tools: list[LLMTool] | None, temperature: float, avoid_repetition: bool, stop: list[str] | None, @@ -1264,6 +1263,8 @@ def _run_groq_chat( "max_tokens": max_tokens, "temperature": temperature, } + if tools: + data["tools"] = [tool.spec for tool in tools] if avoid_repetition: data["frequency_penalty"] = 0.1 data["presence_penalty"] = 0.25