From 782f1d8980110f130556c24ae7048a6f978b502b Mon Sep 17 00:00:00 2001 From: Mark Ericksen Date: Tue, 26 Nov 2024 16:11:22 -0700 Subject: [PATCH] cancel a message delta when received "overloaded" error - in apply_delta, the error can be returned as a delta because that's how it was received --- lib/chains/llm_chain.ex | 7 ++++++- test/chains/llm_chain_test.exs | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/lib/chains/llm_chain.ex b/lib/chains/llm_chain.ex index b1bb516..d57ed9b 100644 --- a/lib/chains/llm_chain.ex +++ b/lib/chains/llm_chain.ex @@ -454,7 +454,7 @@ defmodule LangChain.Chains.LLMChain do completes the message, the LLMChain is updated to clear the `delta` and the `last_message` and list of messages are updated. """ - @spec apply_delta(t(), MessageDelta.t()) :: t() + @spec apply_delta(t(), MessageDelta.t() | {:error, LangChainError.t()}) :: t() def apply_delta(%LLMChain{delta: nil} = chain, %MessageDelta{} = new_delta) do %LLMChain{chain | delta: new_delta} end @@ -464,6 +464,11 @@ defmodule LangChain.Chains.LLMChain do delta_to_message_when_complete(%LLMChain{chain | delta: merged}) end + # Handle when the server is overloaded and cancelled the stream on the server side. + def apply_delta(%LLMChain{} = chain, {:error, %LangChainError{type: "overloaded"}}) do + cancel_delta(chain, :cancelled) + end + @doc """ Convert any hanging delta of the chain to a message and append to the chain. diff --git a/test/chains/llm_chain_test.exs b/test/chains/llm_chain_test.exs index e0b4092..53b0d6c 100644 --- a/test/chains/llm_chain_test.exs +++ b/test/chains/llm_chain_test.exs @@ -472,6 +472,27 @@ defmodule LangChain.Chains.LLMChainTest do assert tool_call.arguments == %{"expression" => "100 + 300 - 200"} assert updated_chain.messages == [last] end + + test "cancels the current delta when applying an overloaded error", %{chain: chain} do + assert chain.messages == [] + + updated_chain = + chain + |> LLMChain.apply_delta( + MessageDelta.new!(%{role: :assistant, content: "Greetings from "}) + ) + |> LLMChain.apply_delta(MessageDelta.new!(%{content: "your "})) + |> LLMChain.apply_delta(MessageDelta.new!(%{content: "favorite "})) + |> LLMChain.apply_delta({:error, LangChainError.exception(type: "overloaded", message: "Overloaded")}) + + # the delta is complete and removed from the chain + assert updated_chain.delta == nil + # the delta is converted to a message and applied to the messages + assert [%Message{} = new_message] = updated_chain.messages + assert new_message.role == :assistant + assert new_message.content == "Greetings from your favorite " + assert new_message.status == :cancelled + end end describe "apply_deltas/2" do