Enable tool support for ollama #164

Open · wants to merge 1 commit into base: main
166 changes: 146 additions & 20 deletions lib/chat_models/chat_ollama_ai.ex
@@ -37,7 +37,11 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
alias LangChain.ChatModels.ChatModel
alias LangChain.ChatModels.ChatOpenAI
alias LangChain.Message
alias LangChain.Message.ToolCall
alias LangChain.Message.ToolResult
alias LangChain.MessageDelta
alias LangChain.Function
alias LangChain.FunctionParam
alias LangChain.LangChainError
alias LangChain.Utils

@@ -186,14 +190,24 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
|> validate_number(:mirostat_eta, greater_than_or_equal_to: 0.0, less_than_or_equal_to: 1.0)
end

defp messages_for_api(messages) do
Enum.reduce(messages, [], fn m, acc ->
case for_api(m) do
data when is_map(data) -> [data | acc]
data when is_list(data) -> Enum.reverse(data) ++ acc
end
end)
|> Enum.reverse()
end
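
Not part of the diff — a sketch of the flattening this helper performs: a :tool message carrying several results expands into multiple API maps, while ordinary messages map one-to-one (message contents here are illustrative):

messages_for_api([
  Message.new_user!("What's the weather in Paris?"),
  %Message{role: :tool, tool_results: [result_a, result_b]}
])
#=> [
#      %{"role" => :user, "content" => "What's the weather in Paris?"},
#      %{"role" => :tool, "content" => "result A"},
#      %{"role" => :tool, "content" => "result B"}
#    ]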

@doc """
Return the params formatted for an API request.
"""
- def for_api(%ChatOllamaAI{} = model, messages, _functions) do
+ def for_api(%ChatOllamaAI{} = model, messages, tools) do
%{
model: model.model,
temperature: model.temperature,
- messages: messages |> Enum.map(&ChatOpenAI.for_api/1),
+ messages: messages_for_api(messages),
stream: model.stream,
seed: model.seed,
num_ctx: model.num_ctx,
@@ -212,15 +226,94 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
top_k: model.top_k,
top_p: model.top_p
}
|> Utils.conditionally_add_to_map(:tools, get_tools_for_api(tools))
end
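
Not part of the diff — a hedged sketch of the request body this now yields. The :tools key is added conditionally, so requests without tools should be unchanged (assuming Utils.conditionally_add_to_map/3 skips empty lists); the model settings and tool below are assumptions:

for_api(%ChatOllamaAI{model: "llama2", temperature: 0.8, stream: false}, messages, [weather_fn])
#=> %{
#      model: "llama2",
#      temperature: 0.8,
#      stream: false,
#      messages: [%{"role" => :user, "content" => "What's the weather in Paris?"}],
#      ...,
#      tools: [%{"type" => "function", "function" => %{"name" => "get_weather", ...}}]
#    }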

def for_api(%Message{role: :assistant, tool_calls: tool_calls} = msg)
when is_list(tool_calls) do
%{
"role" => :assistant,
"content" => msg.content
}
|> Utils.conditionally_add_to_map("tool_calls", Enum.map(tool_calls, &for_api(&1)))
end

# ToolCall support
def for_api(%ToolCall{type: :function} = fun) do
%{
"id" => fun.call_id,
"type" => "function",
"function" => %{
"name" => fun.name,
"arguments" => fun.arguments
}
}
end
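
Illustration only — what this clause produces for a %ToolCall{} (values assumed):

for_api(%ToolCall{
  call_id: "call_123",
  type: :function,
  name: "get_weather",
  arguments: %{"city" => "Paris"}
})
#=> %{
#      "id" => "call_123",
#      "type" => "function",
#      "function" => %{"name" => "get_weather", "arguments" => %{"city" => "Paris"}}
#    }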

# Function support
def for_api(%Function{} = fun) do
%{
"name" => fun.name,
"parameters" => get_parameters(fun)
}
|> Utils.conditionally_add_to_map("description", fun.description)
end

def for_api(%Message{role: :tool, tool_results: tool_results}) when is_list(tool_results) do
Enum.map(tool_results, &for_api/1)
end

def for_api(%ToolResult{content: content}) do
%{
"role" => :tool,
"content" => content
}
end

def for_api(%Message{content: content} = msg) when is_binary(content) do
%{
"role" => msg.role,
"content" => msg.content
}
|> Utils.conditionally_add_to_map("name", msg.name)
end

def for_api(%Message{role: :user, content: content} = msg) when is_list(content) do
%{
"role" => msg.role,
"content" => Enum.map(content, &for_api(&1))
}
|> Utils.conditionally_add_to_map("name", msg.name)
end

defp get_tools_for_api(nil), do: []

defp get_tools_for_api(tools) do
Enum.map(tools, fn %Function{} = function ->
%{"type" => "function", "function" => for_api(function)}
end)
end

defp get_parameters(%Function{parameters: [], parameters_schema: nil} = _fun) do
%{
"type" => "object",
"properties" => %{}
}
end

defp get_parameters(%Function{parameters: [], parameters_schema: schema} = _fun)
when is_map(schema) do
schema
end

defp get_parameters(%Function{parameters: params} = _fun) do
FunctionParam.to_parameters_schema(params)
end
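
A sketch of the three parameter paths, not part of the diff: no parameters yield an empty object schema, an explicit parameters_schema map passes through verbatim, and a FunctionParam list is converted with FunctionParam.to_parameters_schema/1. The get_weather function below is an assumption, and the output shape assumes to_parameters_schema/1 emits standard JSON Schema:

fun =
  Function.new!(%{
    name: "get_weather",
    description: "Get the current weather for a city",
    parameters: [FunctionParam.new!(%{name: "city", type: :string, required: true})],
    function: fn _args, _context -> {:ok, "..."} end
  })

get_parameters(fun)
#=> %{
#      "type" => "object",
#      "properties" => %{"city" => %{"type" => "string"}},
#      "required" => ["city"]
#    }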

@doc """
Calls the Ollama chat completion API using the configured struct, with
either a simple message or a list of messages to act as the prompt.

**NOTE:** Tool support requires an Ollama version that implements it
server-side. More information here: https://github.com/jmorganca/ollama/issues/1729

**NOTE:** This function *can* be used directly, but the primary interface
should be through `LangChain.Chains.LLMChain`. The `ChatOllamaAI` module is more focused on
translating the `LangChain` data structures to and from the Ollama API.
@@ -233,21 +326,20 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
"""

@impl ChatModel
- def call(ollama_ai, prompt, functions \\ [])
+ def call(ollama_ai, prompt, tools \\ [])

- def call(%ChatOllamaAI{} = ollama_ai, prompt, functions) when is_binary(prompt) do
+ def call(%ChatOllamaAI{} = ollama_ai, prompt, tools) when is_binary(prompt) do
messages = [
Message.new_system!(),
Message.new_user!(prompt)
]

- call(ollama_ai, messages, functions)
+ call(ollama_ai, messages, tools)
end

- def call(%ChatOllamaAI{} = ollama_ai, messages, functions)
-   when is_list(messages) do
+ def call(%ChatOllamaAI{} = ollama_ai, messages, tools) when is_list(messages) do
try do
- case do_api_request(ollama_ai, messages, functions) do
+ case __MODULE__.do_api_request(ollama_ai, messages, tools) do
{:error, reason} ->
{:error, reason}

@@ -275,24 +367,24 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
#
# Retries the request up to 3 times on transient errors with a 1 second delay
@doc false
- @spec do_api_request(t(), [Message.t()], [Function.t()]) ::
+ @spec do_api_request(t(), [Message.t()], ChatModel.tools(), integer()) ::
list() | struct() | {:error, String.t()}
- def do_api_request(ollama_ai, messages, functions, retry_count \\ 3)
+ def do_api_request(ollama_ai, messages, tools, retry_count \\ 3)

- def do_api_request(_ollama_ai, _messages, _functions, 0) do
+ def do_api_request(_ollama_ai, _messages, _tools, 0) do
raise LangChainError, "Retries exceeded. Connection failed."
end

def do_api_request(
%ChatOllamaAI{stream: false} = ollama_ai,
messages,
- functions,
+ tools,
retry_count
) do
req =
Req.new(
url: ollama_ai.endpoint,
- json: for_api(ollama_ai, messages, functions),
+ json: for_api(ollama_ai, messages, tools),
receive_timeout: ollama_ai.receive_timeout,
retry: :transient,
max_retries: 3,
@@ -317,7 +409,7 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
{:error, %Req.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
- do_api_request(ollama_ai, messages, functions, retry_count - 1)
+ do_api_request(ollama_ai, messages, tools, retry_count - 1)

other ->
Logger.error("Unexpected and unhandled API response! #{inspect(other)}")
@@ -328,12 +420,12 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
def do_api_request(
%ChatOllamaAI{stream: true} = ollama_ai,
messages,
- functions,
+ tools,
retry_count
) do
Req.new(
url: ollama_ai.endpoint,
- json: for_api(ollama_ai, messages, functions),
+ json: for_api(ollama_ai, messages, tools),
receive_timeout: ollama_ai.receive_timeout
)
|> Req.post(
@@ -357,7 +449,7 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
{:error, %Req.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
- do_api_request(ollama_ai, messages, functions, retry_count - 1)
+ do_api_request(ollama_ai, messages, tools, retry_count - 1)

other ->
Logger.error(
@@ -372,6 +464,18 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
create_message(message, :complete, MessageDelta)
end

def do_process_response(model, %{
"message" => %{"tool_calls" => calls} = message,
"done" => true
})
when calls != [] do
message
|> Map.merge(%{
"tool_calls" => Enum.map(calls, &do_process_response(model, &1))
})
|> create_message(:complete, Message)
end
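
For context — the Ollama response shape this clause pattern-matches, reconstructed from the match itself (field values are illustrative). The raw maps under "tool_calls" are each run through do_process_response/2, becoming %ToolCall{} structs before create_message/3 builds the final %Message{}:

%{
  "done" => true,
  "message" => %{
    "role" => "assistant",
    "content" => "",
    "tool_calls" => [
      %{"function" => %{"name" => "get_weather", "arguments" => %{"city" => "Paris"}}}
    ]
  }
}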

def do_process_response(_model, %{"message" => message, "done" => true}) do
create_message(message, :complete, Message)
end
@@ -385,6 +489,28 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
{:error, reason}
end

def do_process_response(_model, %{
"function" => %{
"arguments" => args,
"name" => name
}
}) do
case ToolCall.new(%{
call_id: Ecto.UUID.generate(),
type: :function,
name: name,
arguments: args
}) do
{:ok, %ToolCall{} = call} ->
call

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to process ToolCall for a function. Reason: #{reason}")
{:error, reason}
end
end
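
Ollama does not include an id with a tool call, which is why this clause mints one with Ecto.UUID.generate/0. An IEx sketch (values illustrative):

iex> do_process_response(model, %{
...>   "function" => %{"arguments" => %{"city" => "Paris"}, "name" => "get_weather"}
...> })
%ToolCall{call_id: "0b1c9f2e-...", type: :function, name: "get_weather", arguments: %{"city" => "Paris"}}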

defp create_message(message, status, message_type) do
case message_type.new(Map.merge(message, %{"status" => status})) do
{:ok, new_message} ->