implement initial support for fallbacks #207

Merged 1 commit on Dec 3, 2024
187 changes: 173 additions & 14 deletions lib/chains/llm_chain.ex
@@ -41,6 +41,83 @@ defmodule LangChain.Chains.LLMChain do
|> LLMChain.run()

In the LiveView, a `handle_info` function executes with the received message.

## Fallbacks

When running a chain, the `:with_fallbacks` option can be used to provide a
list of fallback chat models to try when a failure is encountered.

Requests to language model APIs can fail because of rate limiting, downtime,
or other problems with the underlying service. As an LLM application moves
into production, it becomes increasingly important to safeguard against these
failures. That's what fallbacks are designed to provide.

A **fallback** is an alternative plan that may be used in an emergency.

A `before_fallback` function can be provided to alter or return a different
chain to use with the fallback LLM. This is important because the prompts
needed often differ for a fallback LLM. For example, if your OpenAI
completion fails, a different prompt may be needed when retrying it with an
Anthropic fallback.

### Fallback for LLM API Errors

This is perhaps the most common use case for fallbacks. A request to an LLM
API can fail for a variety of reasons: the API could be down, you could have
hit rate limits, or something else could go wrong. Using fallbacks can help
protect against these types of failures.

## Fallback Examples

A simple fallback that tries a different LLM chat model:

    fallback_llm = ChatAnthropic.new!(%{stream: false})

    {:ok, updated_chain} =
      %{llm: ChatOpenAI.new!(%{stream: false})}
      |> LLMChain.new!()
      |> LLMChain.add_message(Message.new_system!("OpenAI system prompt"))
      |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
      |> LLMChain.run(with_fallbacks: [fallback_llm])

Note the `with_fallbacks: [fallback_llm]` option when running the chain.

This example uses the `:before_fallback` option to provide a function that can
modify or return an alternate chain when used with a certain LLM. Also note
the utility function `LangChain.Utils.replace_system_message!/2` is used for
swapping out the system message when falling back to a different LLM.

    fallback_llm = ChatAnthropic.new!(%{stream: false})

    {:ok, updated_chain} =
      %{llm: ChatOpenAI.new!(%{stream: false})}
      |> LLMChain.new!()
      |> LLMChain.add_message(Message.new_system!("OpenAI system prompt"))
      |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
      |> LLMChain.run(
        with_fallbacks: [fallback_llm],
        before_fallback: fn chain ->
          case chain.llm do
            %ChatAnthropic{} ->
              # replace the system message
              %LLMChain{
                chain
                | messages:
                    Utils.replace_system_message!(
                      chain.messages,
                      Message.new_system!("Anthropic system prompt")
                    )
              }

            _open_ai ->
              chain
          end
        end
      )

See `LangChain.Chains.LLMChain.run/2` for more details.
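
As a minimal sketch of handling the result of a run that used fallbacks, the
helper below pattern matches on the `{:ok, chain}` and `{:error, chain, error}`
tuples that `run/2` returns. The `RunResultSketch` module and its `summarize/1`
function are hypothetical names, and plain maps stand in for the `%LLMChain{}`
and `%LangChainError{}` structs so the snippet runs without the library.

```elixir
defmodule RunResultSketch do
  # Hypothetical helper, not part of LangChain. Plain maps stand in for
  # %LLMChain{} and %LangChainError{} so the sketch has no dependencies.

  # Success: report which model produced the answer.
  def summarize({:ok, chain}), do: "answered by " <> inspect(chain.llm)

  # Every model in the list failed, including the original one.
  def summarize({:error, _chain, %{type: "all_fallbacks_failed"}}),
    do: "every model failed"

  # Any other error, such as one from a run without fallbacks.
  def summarize({:error, _chain, %{message: message}}), do: "failed: " <> message
end
```

Matching on the error `type` rather than the message text keeps the caller
independent of the exact wording of the error message.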

"""
use Ecto.Schema
import Ecto.Changeset
@@ -93,7 +170,11 @@ defmodule LangChain.Chains.LLMChain do
# Track the last `%Message{}` received in the chain.
field :last_message, :any, virtual: true
# Internally managed. The list of exchanged messages during a `run` function
# execution.
# execution. A single run can create several new messages: for example, an
# Assistant message with one or more ToolCalls, followed by a message with
# tool results, some of which may have failed and require the LLM to try
# again. This list tracks the full set of messages exchanged during a
# single run.
field :exchanged_messages, {:array, :any}, default: [], virtual: true
# Track if the state of the chain expects a response from the LLM. This
# happens after sending a user message, when a tool_call is received, or
@@ -235,6 +316,21 @@ defmodule LangChain.Chains.LLMChain do
are evaluated, the `ToolResult` messages are returned to the LLM giving it
an opportunity to use the `ToolResult` information in an assistant response
message. In essence, this mode always gives the LLM the last word.

- `with_fallbacks: [...]` - Provide a list of chat models to use as a fallback
when one fails. This helps a production system remain operational when an
API limit is reached, an LLM service is overloaded or down, or something
else new and exciting goes wrong.

When all fallbacks fail, a `%LangChainError{type: "all_fallbacks_failed"}`
is returned in the error response.

- `before_fallback: fn chain -> modified_chain end` - A `before_fallback`
function is called before the LLM call is made. **NOTE: When provided, it
also fires for the first attempt.** This allows a chain to be modified or
replaced before running against the configured LLM. This is helpful, for
example, when a different system prompt is needed for Anthropic vs OpenAI.
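
To illustrate that the hook runs for every attempt, including the first, here
is a minimal sketch that does not use the library. The `BeforeFallbackSketch`
module, its functions, and the map-based "chain" are hypothetical stand-ins.
The sketch also applies the hook to the original chain each time, which is a
simplification.

```elixir
defmodule BeforeFallbackSketch do
  # Hypothetical stand-ins, not part of LangChain: a "chain" is a plain map
  # with :llm and :system_prompt keys, and each "llm" is an atom.

  # The hook: pick a system prompt that suits the model about to be tried.
  def adjust(%{llm: :anthropic} = chain),
    do: %{chain | system_prompt: "Anthropic system prompt"}

  def adjust(chain), do: chain

  # Build the chain used for each attempt. The hook runs for every model,
  # including the first one in the list.
  def attempts(chain, llms) do
    Enum.map(llms, fn llm -> adjust(%{chain | llm: llm}) end)
  end
end
```

Because the hook also sees the first model, it needs a clause that returns
the chain unchanged for models that should keep the original prompt.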

"""
@spec run(t(), Keyword.t()) :: {:ok, t()} | {:error, t(), LangChainError.t()}
def run(chain, opts \\ [])
@@ -253,22 +349,85 @@ defmodule LangChain.Chains.LLMChain do
# clear the set of exchanged messages.
chain = clear_exchanged_messages(chain)

case Keyword.get(opts, :mode, nil) do
nil ->
# run the chain and format the return
case do_run(chain) do
{:ok, chain} ->
{:ok, chain}
# determine which function to run based on the mode.
function_to_run =
case Keyword.get(opts, :mode, nil) do
nil ->
&do_run/1

{:error, _chain, _reason} = error ->
error
end
:while_needs_response ->
&run_while_needs_response/1

:until_success ->
&run_until_success/1
end

# Run the chain and return the success or error results. NOTE: We do not add
# the current LLM to the list and process everything through a single
# codepath because failing after attempted fallbacks returns a different
# error.
if Keyword.has_key?(opts, :with_fallbacks) do
# run the function, using fallbacks as needed.
with_fallbacks(chain, opts, function_to_run)
else
# run it directly right now and return the success or error
function_to_run.(chain)
end
end

defp with_fallbacks(%LLMChain{} = chain, opts, run_fn) do
# Sources of inspiration:
# - https://python.langchain.com/v0.1/docs/guides/productionization/fallbacks/
# - https://python.langchain.com/docs/how_to/fallbacks/

llm_list = Keyword.fetch!(opts, :with_fallbacks)
before_fallback_fn = Keyword.get(opts, :before_fallback, nil)

:while_needs_response ->
run_while_needs_response(chain)
# Try the chain with each LLM in turn. The current LLM goes first in the
# list so that all of them are processed the same way.
try_chain_with_llm(chain, [chain.llm | llm_list], before_fallback_fn, run_fn)
end

# nothing left to try
defp try_chain_with_llm(chain, [], _before_fallback_fn, _run_fn) do
{:error, chain,
LangChainError.exception(
type: "all_fallbacks_failed",
message: "Failed all attempts to generate response"
)}
end

defp try_chain_with_llm(chain, [llm | tail], before_fallback_fn, run_fn) do
use_chain = %LLMChain{chain | llm: llm}

use_chain =
if before_fallback_fn do
# use the returned chain from the before_fallback function.
before_fallback_fn.(use_chain)
else
use_chain
end

try do
case run_fn.(use_chain) do
{:ok, result} ->
{:ok, result}

{:error, _error_chain, reason} ->
# run attempt received an error. Try again with the next LLM
Logger.warning("LLM call failed, using next fallback. Reason: #{inspect(reason)}")

try_chain_with_llm(use_chain, tail, before_fallback_fn, run_fn)
end
rescue
err ->
# Log the error and try again.
Logger.error(
"Rescued from exception during with_fallback processing. Error: #{inspect(err)}"
)

:until_success ->
run_until_success(chain)
try_chain_with_llm(use_chain, tail, before_fallback_fn, run_fn)
end
end
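
The recursion above can be reduced to a small, dependency-free sketch: try the
head of the list, return on success, and recurse on the tail after either an
error tuple or a raised exception. The `FallbackSketch` module, `try_each/1`,
and the `:all_fallbacks_failed` atom are hypothetical names used only for
illustration. This is not the PR's implementation.

```elixir
defmodule FallbackSketch do
  # Hypothetical stand-in for the recursion, not the library's code. Each
  # attempt is a zero-arity function returning {:ok, value} or
  # {:error, reason}.

  # Nothing left to try.
  def try_each([]), do: {:error, :all_fallbacks_failed}

  def try_each([attempt | rest]) do
    try do
      case attempt.() do
        {:ok, value} -> {:ok, value}
        # An error tuple moves on to the next attempt.
        {:error, _reason} -> try_each(rest)
      end
    rescue
      # A raised exception is treated like an error tuple.
      _exception -> try_each(rest)
    end
  end
end
```

Keeping the empty-list clause separate is what lets the caller receive a
distinct error once every option has been exhausted.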

2 changes: 1 addition & 1 deletion lib/langchain_error.ex
@@ -30,7 +30,7 @@ defmodule LangChain.LangChainError do
Create the exception using either a message or a changeset whose errors are
converted to a message.
"""
@spec exception(message :: String.t() | Ecto.Changeset.t()) :: t()
@spec exception(message :: String.t() | Ecto.Changeset.t()) :: t() | no_return()
def exception(message) when is_binary(message) do
%LangChainError{message: message}
end
11 changes: 11 additions & 0 deletions lib/utils.ex
@@ -305,4 +305,15 @@ defmodule LangChain.Utils do

{List.first(system), other}
end

@doc """
Replace the system message with a new system message. This retains all other
messages as-is. An error is raised if there is more than one system message.
"""
@spec replace_system_message!([Message.t()], Message.t()) :: [Message.t()] | no_return()
def replace_system_message!(messages, new_system_message) do
{_old_system, rest} = split_system_message(messages)
# return the new system message along with the rest
[new_system_message | rest]
end
end
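
A minimal, dependency-free sketch of the same idea follows. It assumes that
messages are plain maps with a `:role` key and that the split can be done with
`Enum.split_with/2`. The `ReplaceSystemSketch` module and the `ArgumentError`
it raises are assumptions for illustration, since the body of
`split_system_message/1` is not shown in this diff.

```elixir
defmodule ReplaceSystemSketch do
  # Hypothetical stand-in, not the library's code: messages are plain maps
  # with a :role key instead of %Message{} structs.

  def replace_system_message!(messages, new_system_message) do
    # Separate system messages from everything else, keeping order.
    {system, rest} = Enum.split_with(messages, &(&1.role == :system))

    # Assumption: more than one system message is an error.
    if length(system) > 1 do
      raise ArgumentError, "expected at most one system message"
    end

    # The new system message goes first, followed by the other messages.
    [new_system_message | rest]
  end
end
```

Note that the new system message is always placed at the head of the list,
even when the original list had no system message.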
107 changes: 106 additions & 1 deletion test/chains/llm_chain_test.exs
@@ -17,6 +17,7 @@ defmodule LangChain.Chains.LLMChainTest do
alias LangChain.MessageDelta
alias LangChain.LangChainError
alias LangChain.MessageProcessors.JsonProcessor
alias LangChain.Utils

@anthropic_test_model "claude-3-opus-20240229"

@@ -483,7 +484,9 @@ defmodule LangChain.Chains.LLMChainTest do
)
|> LLMChain.apply_delta(MessageDelta.new!(%{content: "your "}))
|> LLMChain.apply_delta(MessageDelta.new!(%{content: "favorite "}))
|> LLMChain.apply_delta({:error, LangChainError.exception(type: "overloaded", message: "Overloaded")})
|> LLMChain.apply_delta(
{:error, LangChainError.exception(type: "overloaded", message: "Overloaded")}
)

# the delta is complete and removed from the chain
assert updated_chain.delta == nil
@@ -1429,6 +1432,108 @@ defmodule LangChain.Chains.LLMChainTest do
assert reason.message == "Exceeded max failure count"
assert updated_chain.current_failure_count == 3
end

test "with_fallbacks: re-runs with next LLM after first fails" do
# Made NOT LIVE here - handles two calls
expect(ChatOpenAI, :call, fn _model, _messages, _tools ->
# IO.puts "FAKE OpenAI ERROR RESULT RETURNED"
{:error,
LangChainError.exception(type: "too_many_requests", message: "Too many requests!")}
end)

expect(ChatAnthropic, :call, fn _model, _messages, _tools ->
{:ok,
[
Message.new_assistant!(%{content: "fallback worked!"})
]}
end)

{:ok, updated_chain} =
%{llm: ChatOpenAI.new!(%{stream: false})}
|> LLMChain.new!()
|> LLMChain.add_message(Message.new_system!())
|> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
|> LLMChain.run(with_fallbacks: [ChatAnthropic.new!(%{stream: false})])

# stopped after processing a successful assistant response
assert updated_chain.last_message.role == :assistant
assert updated_chain.last_message.content == "fallback worked!"
end

test "with_fallbacks: runs each LLM option and returns when all failed" do
# Made NOT LIVE here - handles two calls
expect(ChatOpenAI, :call, fn _model, _messages, _tools ->
# IO.puts "FAKE OpenAI ERROR RESULT RETURNED"
{:error,
LangChainError.exception(type: "too_many_requests", message: "Too many requests!")}
end)

expect(ChatAnthropic, :call, fn _model, _messages, _tools ->
{:error, LangChainError.exception(type: "overloaded", message: "Overloaded")}
end)

{:error, _updated_chain, reason} =
%{llm: ChatOpenAI.new!(%{stream: false})}
|> LLMChain.new!()
|> LLMChain.add_message(Message.new_system!())
|> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
|> LLMChain.run(with_fallbacks: [ChatAnthropic.new!(%{stream: false})])

assert %LangChainError{
type: "all_fallbacks_failed",
message: "Failed all attempts to generate response"
} == reason
end

test "with_fallbacks: runs before_fallback function and uses the resulting chain" do
# Made NOT LIVE here - handles two calls
expect(ChatOpenAI, :call, fn _model, _messages, _tools ->
# IO.puts "FAKE OpenAI ERROR RESULT RETURNED"
{:error,
LangChainError.exception(type: "too_many_requests", message: "Too many requests!")}
end)

expect(ChatAnthropic, :call, fn _model, _messages, _tools ->
{:ok, Message.new_assistant!(%{content: "Claude says it's because it's not red."})}
end)

{:ok, updated_chain} =
%{llm: ChatOpenAI.new!(%{stream: false})}
|> LLMChain.new!()
|> LLMChain.add_message(Message.new_system!("OpenAI system prompt"))
|> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
|> LLMChain.run(
with_fallbacks: [
ChatAnthropic.new!(%{stream: false})
],
before_fallback: fn chain ->
send(self(), :before_fallback_fired)

case chain.llm do
%ChatAnthropic{} ->
# replace the system message
%LLMChain{
chain
| messages:
Utils.replace_system_message!(
chain.messages,
Message.new_system!("Anthropic system prompt")
)
}

_open_ai ->
chain
end
end
)

assert [system_msg | _rest] = updated_chain.messages
assert system_msg.role == :system
assert system_msg.content == "Anthropic system prompt"
assert updated_chain.last_message.role == :assistant
assert updated_chain.last_message.content == "Claude says it's because it's not red."
assert_received :before_fallback_fired
end
end

describe "increment_current_failure_count/1" do