Interrupting completion stream #126
Comments
How to cancel depends on your application and how it is set up. For a LiveView example where the LLMChain is run in an async process, it can be as easy as killing the process. Then there's some cleanup that may need to be done on the partially completed message. The cleanup will depend on your app. In my case, I close out the pending message with a cancelled status but leave it in the chain. Here's an example:

# cancel the async process
def handle_event("cancel", _params, socket) do
  socket =
    socket
    |> cancel_async(:running_llm)
    |> assign(:async_result, %AsyncResult{})
    |> put_flash(:info, "Cancelled")
    |> close_pending_as_cancelled()

  {:noreply, socket}
end

# Close out any pending delta messages as cancelled and save what we've
# received so far. This works when we initiate a cancel or we receive an error
# from the async function.
defp close_pending_as_cancelled(socket) do
  chain = socket.assigns.llm_chain

  # the task exited with an incomplete delta
  if chain.delta != nil do
    # most likely was cancelled. An incomplete
    # delta can be converted to a "cancelled" message
    updated_chain = LLMChain.cancel_delta(chain, :cancelled)

    # save the cancelled message
    Messages.create_message(
      socket.assigns.conversation.id,
      Map.from_struct(updated_chain.last_message)
    )

    socket
    |> assign(:llm_chain, updated_chain)
    |> assign_messages()
  else
    socket
  end
end

I wrote up the Phoenix side of cancelling an async operation here: https://fly.io/phoenix-files/abusing-liveview-new-async-assigns-feature/

Yes, this deserves some documentation or a guide.
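For completeness, here is a rough sketch of the other side of that handler: starting the chain under LiveView's start_async/3 so the cancel_async(:running_llm) call above has a task to kill. This is not copied from the demo project; the event name, the {:chat_delta, ...} / {:chat_message, ...} message shapes, and the assign names are assumptions that simply mirror the snippet above.

def handle_event("submit", %{"content" => content}, socket) do
  updated_chain =
    socket.assigns.llm_chain
    |> LLMChain.add_message(Message.new_user!(content))

  # Capture the LiveView pid so the streaming callback (running in the task)
  # can send deltas back to the process that owns the chain in its assigns.
  live_view_pid = self()

  socket =
    socket
    |> assign(:llm_chain, updated_chain)
    |> assign(:async_result, AsyncResult.loading())
    |> start_async(:running_llm, fn ->
      callback = fn
        %MessageDelta{} = delta -> send(live_view_pid, {:chat_delta, delta})
        %Message{} = message -> send(live_view_pid, {:chat_message, message})
      end

      LLMChain.run(updated_chain, callback_fn: callback)
    end)

  {:noreply, socket}
end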
Solved in the next message. Preserved for historical purposes.
Hi @brainlid! I finally had some time to experiment with early cancellation and created a minimal example. (Apologies for the bulkiness; I couldn't think of a more concise version.) My goal is to receive an intermediate aggregated message that was accumulated during the generation stage before the process was terminated. I was able to emulate process cancellation, but I have no idea how to solve the last piece: receiving the accumulated message.

defmodule LLMChainTest do
  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias LangChain.MessageDelta

  # Build an LLMChain with the given chat model, messages, and functions
  defp build_llm_chain(chat_model, messages, functions) do
    %{llm: chat_model, verbose: false}
    |> LLMChain.new!()
    |> LLMChain.add_messages(messages)
    |> LLMChain.add_functions(functions)
  end

  # Run the LLMChain with a provided callback function
  defp run_llm_chain(chain, callback) do
    chain
    |> LLMChain.run(callback_fn: callback)
  end

  # Spawn an LLM request as a separate process
  def spawn_llm(chat_model, messages, functions, callback) do
    IO.puts("spawn_llm:enter")
    parent = self()

    try do
      # Build the LLMChain
      chain = build_llm_chain(chat_model, messages, functions)

      # Start an asynchronous task to run the LLM chain
      task =
        Task.async(fn ->
          result =
            try do
              run_llm_chain(chain, callback)
            catch
              kind, reason -> {:error, {kind, reason}}
            end

          # Send the result back to the parent process
          send(parent, {:llm_request_success, result})
          result
        end)

      IO.inspect(task, label: "spawn_llm:Task")

      # Wait for the LLM result or cancellation
      wait_for_llm_result(task, chain)
    catch
      err ->
        IO.puts("spawn_llm:catch:err: #{inspect(err)}")
    end
  end

  # Wait for the LLM result or handle cancellation
  defp wait_for_llm_result(task, chain) do
    receive do
      {:cancel} ->
        # If a cancellation message is received, shut down the task
        IO.puts("wait_for_llm_result:receive:cancel")
        Task.shutdown(task, :brutal_kill)

        ############## THE PROBLEM ##############
        # Here, I want to receive an intermediate message,
        # everything that was generated before the LangChain task was terminated
        ############## THE PROBLEM END ##############
        {:cancelled, nil}

      {:llm_request_success, result} ->
        # If the LLM request is successful, process the result
        IO.puts("wait_for_llm_result:receive:llm_request_success")

        case result do
          {:ok, _chain, response} -> {:ok, response}
          err -> err
        end

      other ->
        # If any other message is received, continue waiting
        IO.inspect(other, label: "wait_for_llm_result:receive:other")
        wait_for_llm_result(task, chain)
    end
  end

  # Test function to demonstrate the usage
  def run_test do
    # Create a new ChatOpenAI model instance
    chat_model = LangChain.ChatModels.ChatOpenAI.new!(%{api_key: "YOUR_KEY_HERE", stream: true})

    # Define the messages for the LLM
    messages = [
      LangChain.Message.new_user!("count from 1 to 50")
    ]

    # No functions are used in this example
    functions = []

    # Define the callback function to handle streaming responses
    callback = fn
      %MessageDelta{} = data ->
        # Handle streaming data
        IO.puts(data.content)

      %Message{} = data ->
        # Handle the completed message
        IO.puts("\n\n")
        IO.inspect(data.content, label: "COMPLETED MESSAGE")
    end

    # Set up a cancellation task that will trigger after 1.5 seconds
    parent = self()

    Task.async(fn ->
      Process.sleep(1500)
      send(parent, {:cancel})
    end)

    # Spawn the LLM request and wait for the result
    result = spawn_llm(chat_model, messages, functions, callback)
    IO.inspect(result, label: "Result")
  end
end

# Run the test
LLMChainTest.run_test()

When you run the code, you will receive output:

Please review the section of the code labeled THE PROBLEM. The desired output should look like this:

Any help would be highly appreciated! Thank you :)
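A lighter-weight variant of what the next comment formalizes, sketched against the example above (not tested): have the callback forward each delta to the spawning process so the accumulated content outlives Task.shutdown/2. It assumes LLMChain.apply_delta/2 is available to fold a delta into the chain, plus LLMChain.cancel_delta/2 as mentioned in the maintainer's comment above.

In spawn_llm/4, build the callback around the captured parent pid:

callback = fn
  %MessageDelta{} = delta -> send(parent, {:delta, delta})
  %Message{} = _message -> :ok
end

Then wait_for_llm_result/2 can keep its own copy of the chain up to date and turn the partial delta into a cancelled message:

defp wait_for_llm_result(task, chain) do
  receive do
    {:delta, delta} ->
      # Fold each streamed delta into the parent's copy of the chain
      wait_for_llm_result(task, LLMChain.apply_delta(chain, delta))

    {:cancel} ->
      Task.shutdown(task, :brutal_kill)
      # Convert the incomplete delta into a message with status :cancelled
      cancelled_chain = LLMChain.cancel_delta(chain, :cancelled)
      {:cancelled, cancelled_chain.last_message}

    {:llm_request_success, {:ok, _chain, response}} ->
      {:ok, response}

    {:llm_request_success, err} ->
      err
  end
end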
I believe the best solution is to store the deltas in separate state. To achieve this, I created a GenServer that executes the chain, stores the intermediate deltas, and, in case of cancellation, combines them and sends a message with the status :canceled. The code below is for demonstration purposes. However, I think it could be beneficial to have a GenServer wrapper for the LLMChain to interact with it more straightforwardly. We can start a GenServer per user, request, session, or any other relevant case and terminate it when necessary. @brainlid, if you find this valuable, I can work on a PR for it. For now, if anyone is interested, below is a minimal solution for handling cancellations and receiving partially generated messages.

Usage:

chat_model =
  LangChain.ChatModels.ChatOpenAI.new!(%{api_key: "XXX", stream: true})

# Define the messages for the LLM
messages = [
  LangChain.Message.new_user!("count from 1 to 20")
]

# No functions are used in this example
tools = []

# Define the callback function to handle streaming responses
callback = fn
  {:message_delta, %LangChain.MessageDelta{} = data, seq_num} ->
    # Handle streaming data, seq_num is a sequential number of the delta
    IO.inspect(data, label: "CB:DELTA (#{seq_num}): ")

  {:message_full, %LangChain.Message{} = data} ->
    # Handle the completed message
    IO.inspect(data.content, label: "CB:COMPLETED (status: #{data.status}) -> content")

  {:error, error} ->
    IO.inspect(error, label: "CB:ERROR")
end

# Generate a unique run_id for the LLM request
run_id = UUID.generate()

# Set up a cancellation task that will trigger after 1.5 seconds
Task.start(fn ->
  Process.sleep(1500)
  ChainServer.cancel(run_id) # Cancel execution with run_id
end)

# Initiate a message generation (synchronous request)
# All events will be passed to the callback, including messages with status :canceled
{:ok, message} =
  ChainServer.execute_once(
    run_id, chat_model, tools, callback, messages
  )

The code:
defmodule ChainServer do
  use GenServer

  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias LangChain.MessageDelta

  # Starting the GenServer and initializing its state
  def start_link(%{run_id: run_id, chat_model: _, tools: _, callback: _} = args) do
    GenServer.start_link(__MODULE__, args, name: get_pid(run_id))
  end

  @impl true
  def init(%{chat_model: _, tools: _, callback: _} = args) do
    state =
      Map.merge(
        args,
        %{deltas: [], seq_num: 0, is_canceled: false, caller: nil}
      )

    {:ok, state}
  end

  def execute_once(run_id, chat_model, tools, callback, messages) do
    task =
      Task.async(fn ->
        {:ok, pid} =
          ChainServer.start_link(%{
            run_id: run_id,
            chat_model: chat_model,
            tools: tools,
            callback: callback
          })

        resp = ChainServer.execute(run_id, messages)
        Process.exit(pid, :kill)
        resp
      end)

    Task.await(task, :infinity)
  end

  # Public API to start the LLM request with a `run_id`
  def execute(run_id, messages) do
    GenServer.call(get_pid(run_id), {:run_llm, messages})
  end

  # Public API to request cancellation with a `run_id`
  def cancel(run_id) do
    GenServer.cast(get_pid(run_id), :cancel)
  end

  @impl true
  def handle_call({:run_llm, messages}, caller, %{chat_model: chat_model, tools: tools} = state) do
    parent = self()

    Task.async(fn ->
      langchain_callback = fn
        %MessageDelta{} = data ->
          GenServer.cast(parent, {:receive_message_delta, data})

        %Message{} = _data ->
          # GenServer.cast(parent, {:receive_message_full, data})
          # we will handle this part with `result` below
          :ok
      end

      result =
        try do
          run_llm_chain(
            build_llm_chain(chat_model, messages, tools),
            langchain_callback
          )
        catch
          kind, reason ->
            {:error, {kind, reason}}
        end

      # Notify the GenServer
      GenServer.cast(parent, {:llm_request_result, result})
      :ok
    end)

    {:noreply, %{state | caller: caller}}
  end

  @impl true
  def handle_cast(
        {:receive_message_delta, message_delta},
        %{callback: callback, is_canceled: is_canceled, seq_num: seq_num} = state
      ) do
    if not is_canceled do
      callback.({:message_delta, message_delta, seq_num})

      state =
        state
        |> Map.put(:seq_num, seq_num + 1)
        |> Map.put(:deltas, [message_delta.content | state.deltas])

      {:noreply, state}
    else
      # when is_canceled == true we don't want to change the state
      {:noreply, state}
    end
  end

  @impl true
  def handle_cast(:cancel, state) do
    GenServer.cast(self(), :cancel_collect_results)
    # set is_canceled = true to prevent future messages from being handled
    {:noreply, %{state | is_canceled: true}}
  end

  def handle_cast(:cancel_collect_results, %{is_canceled: true, deltas: deltas} = state) do
    # Reverse the list to get it in the original order before joining
    message_content =
      deltas
      |> Enum.reverse()
      |> IO.iodata_to_binary()

    message = %Message{
      content: message_content,
      index: 0,
      status: :canceled,
      role: :assistant
    }

    GenServer.cast(self(), {:llm_request_result, {:ok, nil, message}})
    {:noreply, state}
  end

  @impl true
  def handle_cast({:llm_request_result, result}, %{caller: caller, callback: callback} = state) do
    # send results to callback
    case result do
      {:ok, _, message} ->
        callback.({:message_full, message})
        # reply with results to ChainServer.execute
        GenServer.reply(caller, {:ok, message})

      {:error, error} ->
        callback.({:error, error})
        GenServer.reply(caller, {:error, error})
    end

    {:noreply, state}
  end

  defp build_llm_chain(chat_model, messages, tools) do
    %{llm: chat_model, verbose: false}
    |> LLMChain.new!()
    |> LLMChain.add_messages(messages)
    |> LLMChain.add_tools(tools)
  end

  defp run_llm_chain(chain, callback_fn) do
    IO.puts("run_llm_chain")

    chain
    |> LLMChain.run(callback_fn: callback_fn)
  end

  defp get_pid(run_id) do
    {:via, Registry, {LLMChainRegistry, run_id}}
  end
end
defmodule LLMChainTest do
  alias Ecto.UUID

  def run_test do
    # Create a new ChatOpenAI model instance
    chat_model =
      LangChain.ChatModels.ChatOpenAI.new!(%{
        api_key: "XXX",
        stream: true
      })

    # Define the messages for the LLM
    messages = [
      LangChain.Message.new_user!("count from 1 to 20")
    ]

    # No functions are used in this example
    tools = []

    # Define the callback function to handle streaming responses
    callback = fn
      {:message_delta, %LangChain.MessageDelta{} = data, seq_num} ->
        # Handle streaming data
        IO.inspect(data, label: "CB:DELTA (#{seq_num}): ")

      {:message_full, %LangChain.Message{} = data} ->
        # Handle the completed message
        IO.puts("\n")
        IO.inspect(data.content, label: "CB:COMPLETED (status: #{data.status}) -> content")

      {:error, error} ->
        IO.inspect(error, label: "CB:ERROR")
    end

    # Start the GenServer and initialize the LLM request
    run_id = UUID.generate()

    # Set up a cancellation task that will trigger after 1.5 seconds
    Task.start(fn ->
      Process.sleep(1500)
      ChainServer.cancel(run_id)
    end)

    result =
      ChainServer.execute_once(run_id, chat_model, tools, callback, messages)

    IO.inspect(result, label: "final result")
  end
end

# You need to start the Registry before running the test
{:ok, _} = Registry.start_link(keys: :unique, name: LLMChainRegistry)

LLMChainTest.run_test()
Output:
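If this grows beyond a demo, a natural next step would be to start the per-run server under a DynamicSupervisor instead of linking it to a throwaway Task. A sketch, assuming the Registry above plus a supervisor named ChainSupervisor (both names are placeholders):

# In the application supervision tree
children = [
  {Registry, keys: :unique, name: LLMChainRegistry},
  {DynamicSupervisor, strategy: :one_for_one, name: ChainSupervisor}
]

# Starting one ChainServer per run_id / user / session
{:ok, _pid} =
  DynamicSupervisor.start_child(
    ChainSupervisor,
    {ChainServer, %{run_id: run_id, chat_model: chat_model, tools: tools, callback: callback}}
  )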
The idea of two separate processes is how it works. The demo project uses a Task-created process for running the chain, and the other process is the LiveView. Yes, the other process could be a more explicitly managed GenServer. Thanks for sharing your work!
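For anyone wiring up that two-process shape in a LiveView, here is a rough sketch of the receiving side, paired with the start_async handler sketched earlier in the thread. The message shapes and assign names are the same assumptions as before, not code from the demo project; LLMChain.apply_delta/2 is assumed to be the function that folds a delta into the chain.

def handle_info({:chat_delta, delta}, socket) do
  # Fold the streamed delta into the chain held by the LiveView so
  # chain.delta is populated if the user cancels mid-stream.
  updated_chain = LLMChain.apply_delta(socket.assigns.llm_chain, delta)
  {:noreply, assign(socket, :llm_chain, updated_chain)}
end

def handle_info({:chat_message, _message}, socket) do
  # Where the app could persist the completed message
  # (e.g. Messages.create_message/2 as in the cancel path above).
  {:noreply, socket}
end

# If the async task crashes (as opposed to being cancelled from the UI),
# close out the partial message the same way the cancel handler does.
def handle_async(:running_llm, {:exit, _reason}, socket) do
  {:noreply, close_pending_as_cancelled(socket)}
end

def handle_async(:running_llm, {:ok, _run_result}, socket) do
  {:noreply, assign(socket, :async_result, AsyncResult.ok(%AsyncResult{}, :ok))}
end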
Is there a way to interrupt the generation stream? It's technically possible, but I haven't found any mention of it in the docs.
This can be useful for user-facing frontends where a user can abort the assistant's answer mid-stream and rephrase the task.
OpenAI forum: https://community.openai.com/t/interrupting-completion-stream-in-python/30628