
Interrupting completion stream #126

Open

ksanderer opened this issue May 30, 2024 · 4 comments

@ksanderer

Is there a way to interrupt the generation stream? It's technically possible, but I haven't found any mention in the docs.

This would be useful for user-facing frontends where a user can abort the assistant's answer mid-generation and rephrase the task.

OpenAI forum: https://community.openai.com/t/interrupting-completion-stream-in-python/30628

Ashton1998
Jun 2023
I ran a simple test of @thehunmonkgroup's solution.

I made a call to the gpt-3.5-turbo model with the input:

Please introduce GPT model structure as detail as possible

and let the API print all the tokens. The result from the OpenAI usage page is (I am a new user and am not allowed to post media, so I only copy the result):
17 prompt + 441 completion = 568 tokens

After that, I stopped the generation when the number of tokens received was 9; the result is:
17 prompt + 27 completion = 44 tokens

It seems roughly 10 extra tokens are generated after I stop the generation.

Then I stopped the generation when the number was 100; the result is:
17 prompt + 111 completion = 128 tokens

So I think the solution works well, but with an extra 10-20 tokens every time.

@brainlid
Owner

brainlid commented Jun 1, 2024

How to cancel depends on your application and how it is set up. For a LiveView example where the LLMChain is run in an async process, it can be as easy as killing the process. Then there's some cleanup that may need to be done on the partially completed message. The cleanup will depend on your app. In my case, I close out the pending message with a cancelled status but leave it in the chain.

Here's an example:

  # cancel the async process
  def handle_event("cancel", _params, socket) do
    socket =
      socket
      |> cancel_async(:running_llm)
      |> assign(:async_result, %AsyncResult{})
      |> put_flash(:info, "Cancelled")
      |> close_pending_as_cancelled()

    {:noreply, socket}
  end

  # Close out any pending delta messages as cancelled and save what we've
  # received so far. This works when we initiate a cancel or we receive an error
  # from the async function.
  defp close_pending_as_cancelled(socket) do
    chain = socket.assigns.llm_chain

    # the task exited with an incomplete delta
    if chain.delta != nil do
      # most likely was cancelled. An incomplete
      # delta can be converted to a "cancelled" message
      updated_chain = LLMChain.cancel_delta(chain, :cancelled)

      # save the cancelled message
      Messages.create_message(
        socket.assigns.conversation.id,
        Map.from_struct(updated_chain.last_message)
      )

      socket
      |> assign(:llm_chain, updated_chain)
      |> assign_messages()
    else
      socket
    end
  end

I wrote up the Phoenix side of cancelling an async operation here: https://fly.io/phoenix-files/abusing-liveview-new-async-assigns-feature/

Yes, this deserves some documentation or a guide.
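
For context, the cancel handler above pairs with a start_async call along these lines. This is only a sketch: the :running_llm key matches the cancel_async above, but the "ask" event and run_chain/1 are hypothetical stand-ins, and it assumes alias Phoenix.LiveView.AsyncResult is in scope (as the %AsyncResult{} above suggests).

  # Kick off the chain in a LiveView-managed async process so that
  # cancel_async(:running_llm) in the handler above can stop it later.
  def handle_event("ask", %{"question" => _question}, socket) do
    chain = socket.assigns.llm_chain

    socket =
      socket
      |> assign(:async_result, AsyncResult.loading())
      |> start_async(:running_llm, fn ->
        # run_chain/1 stands in for adding the user message to the chain and
        # calling LLMChain.run/2 with a streaming callback
        run_chain(chain)
      end)

    {:noreply, socket}
  end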

@ksanderer
Author

ksanderer commented Jul 16, 2024

Solved in the next message. Preserved for historical purposes.

Hi @brainlid!

I finally had some time to experiment with early cancellation and created a minimal example. (Apologies for the bulkiness; I couldn't think of a more concise version).

My goal is to receive an intermediate aggregated message that was accumulated during the generation stage before the process was terminated.

I was able to emulate process cancellation, but I have no idea how to solve the last piece: receiving the accumulated message.

defmodule LLMChainTest do
  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias LangChain.MessageDelta

  # Build an LLMChain with the given chat model, messages, and functions
  defp build_llm_chain(chat_model, messages, functions) do
    %{llm: chat_model, verbose: false}
    |> LLMChain.new!()
    |> LLMChain.add_messages(messages)
    |> LLMChain.add_functions(functions)
  end

  # Run the LLMChain with a provided callback function
  defp run_llm_chain(chain, callback) do
    chain
    |> LLMChain.run(callback_fn: callback)
  end

  # Spawn an LLM request as a separate process
  def spawn_llm(chat_model, messages, functions, callback) do
    IO.puts("spawn_llm:enter")
    parent = self()

    try do
      # Build the LLMChain
      chain = build_llm_chain(chat_model, messages, functions)

      # Start an asynchronous task to run the LLM chain
      task =
        Task.async(fn ->
          result =
            try do
              run_llm_chain(chain, callback)
            catch
              kind, reason -> {:error, {kind, reason}}
            end

          # Send the result back to the parent process
          send(parent, {:llm_request_success, result})
          result
        end)

      IO.inspect(task, label: "spawn_llm:Task")

      # Wait for the LLM result or cancellation
      wait_for_llm_result(task, chain)
    catch
      err ->
        IO.puts("spawn_llm:catch:err: #{inspect(err)}")
    end
  end

  # Wait for the LLM result or handle cancellation
  defp wait_for_llm_result(task, chain) do
    receive do
      {:cancel} ->
        # If a cancellation message is received, shut down the task
        IO.puts("wait_for_llm_result:receive:cancel")
        Task.shutdown(task, :brutal_kill)
        
        ############## THE PROBLEM ##############
        # Here, I want to receive an intermediate message, 
        # everything that was generated before the LangChain task was terminated
        ############## THE PROBLEM END ##############

        {:cancelled, nil}

      {:llm_request_success, result} ->
        # If the LLM request is successful, process the result
        IO.puts("wait_for_llm_result:receive:llm_request_success")
        case result do
          {:ok, _chain, response} -> {:ok, response}
          err -> err
        end

      other ->
        # If any other message is received, continue waiting
        IO.inspect(other, label: "wait_for_llm_result:receive:other")
        wait_for_llm_result(task, chain)
    end
  end

  # Test function to demonstrate the usage
  def run_test do
    # Create a new ChatOpenAI model instance
    chat_model = LangChain.ChatModels.ChatOpenAI.new!(%{api_key: "YOUR_KEY_HERE", stream: true})

    # Define the messages for the LLM
    messages = [
      LangChain.Message.new_user!("count from 1 to 50")
    ]

    # No functions are used in this example
    functions = []

    # Define the callback function to handle streaming responses
    callback = fn
      %MessageDelta{} = data ->
        # Handle streaming data
        IO.puts(data.content)

      %Message{} = data ->
        # Handle the completed message
        IO.puts("\n\n")
        IO.inspect(data.content, label: "COMPLETED MESSAGE")
    end

    # Set up a cancellation task that will trigger after 1.5 seconds
    parent = self()
    Task.async(fn ->
      Process.sleep(1500)
      send(parent, {:cancel})
    end)

    # Spawn the LLM request and wait for the result
    result = spawn_llm(chat_model, messages, functions, callback)
    IO.inspect(result, label: "Result")
  end
end

# Run the test
LLMChainTest.run_test()

When you run the code, you will see output like this:

spawn_llm:enter
spawn_llm:Task: %Task{
  mfa: {:erlang, :apply, 2},
  owner: #PID<0.99.0>,
  pid: #PID<0.577.0>,
  ref: #Reference<0.0.12675.4275237201.2583232516.136432>
}

1
,
 
2
,
 
3
,
 
4
,
 
5
,
wait_for_llm_result:receive:cancel
Result: {:cancelled, nil}

Please review the section of the code labeled THE PROBLEM. I am unsure how to extract the chain as it existed prior to the process termination.

The desired output should look like this:

spawn_llm:enter
spawn_llm:Task: %Task{
  mfa: {:erlang, :apply, 2},
  owner: #PID<0.99.0>,
  pid: #PID<0.577.0>,
  ref: #Reference<0.0.12675.4275237201.2583232516.136432>
}

1
,
 
2
,
 
3
,
 
4
,
 
5
,
wait_for_llm_result:receive:cancel
COMPLETED MESSAGE: "1, 2, 3, 4, 5,"
Result: {:canceled,
 %LangChain.Message{
   content: "1, 2, 3, 4, 5,",
   index: 0,
   status: :canceled,
   role: :assistant,
   function_name: nil,
   arguments: nil
 }}

Any help would be highly appreciated! Thank you :)
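
One direction that could capture that partial content (a sketch only; DeltaCollector, callback_for/1, and collect/1 are hypothetical helpers, not LangChain APIs): have the callback forward each delta out of the task as it streams, and let the waiting process accumulate the text itself so it can rebuild the partial message after Task.shutdown.

defmodule DeltaCollector do
  alias LangChain.Message
  alias LangChain.MessageDelta

  # Build a callback that forwards each delta's content to a collector process.
  def callback_for(collector_pid) do
    fn
      %MessageDelta{} = delta -> send(collector_pid, {:delta, delta.content})
      %Message{} = _message -> :ok
    end
  end

  # Accumulate delta contents until a cancel arrives, then return the partial text.
  def collect(acc \\ []) do
    receive do
      {:delta, nil} -> collect(acc)
      {:delta, content} -> collect([acc, content])
      {:cancel} -> IO.iodata_to_binary(acc)
    end
  end
end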

@ksanderer
Author

ksanderer commented Jul 26, 2024

I believe the best solution is to store deltas in a separate state. To achieve this, I created a GenServer that executes the Chain, stores the intermediate deltas, and, in case of cancellation, combines them and sends a message with the status "canceled."

The code below is for demonstration purposes. However, I think it could be beneficial to have a GenServer wrapper for the LLMChain to interact with it more straightforwardly. We can start a GenServer per user, request, session, or any other relevant case and terminate it when necessary.

@brainlid If you find this valuable, I can work on a PR for it.

For now, if anyone is interested, below is a minimal solution for handling cancellations and receiving partially generated messages.

Usage:

chat_model =
  LangChain.ChatModels.ChatOpenAI.new!(%{api_key: "XXX", stream: true})

# Define the messages for the LLM
messages = [
  LangChain.Message.new_user!("count from 1 to 20")
]

# No functions are used in this example
tools = []

# Define the callback function to handle streaming responses
callback = fn
  {:message_delta, %LangChain.MessageDelta{} = data, seq_num} ->
    # Handle streaming data, seq_num is a sequential number of the delta
    IO.inspect(data, label: "CB:DELTA (#{seq_num}): ")

  {:message_full, %LangChain.Message{} = data} ->
    # Handle the completed message
    IO.inspect(data.content, label: "CB:COMPLETED (status: #{data.status}) -> content")


  {:error, error} ->
    IO.inspect(error, label: "CB:ERROR")
end

# Generate a unique run_id for the LLM request
run_id = UUID.generate()

# Set up a cancellation task that will trigger after 1.5 seconds
Task.start(fn ->
  Process.sleep(1500)
  ChainServer.cancel(run_id) # Cancel execution with run_id
end)

# Initiate a message generation (synchronous request)
# All events will be passed to the callback, including messages with status :canceled
{:ok, message} = ChainServer.execute_once(
  run_id, chat_model, tools, callback, messages
)

The code:


defmodule ChainServer do
  use GenServer

  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias LangChain.MessageDelta

  # Starting the GenServer and initializing its state
  def start_link(%{run_id: run_id, chat_model: _, tools: _, callback: _} = args) do
    GenServer.start_link(__MODULE__, args, name: get_pid(run_id))
  end

  @impl true
  def init(%{chat_model: _, tools: _, callback: _} = args) do
    state = Map.merge(
      args,
      %{deltas: [], seq_num: 0, is_canceled: false, caller: nil}
    )
    {:ok, state}
  end

  def execute_once(run_id, chat_model, tools, callback, messages) do
    task =
      Task.async(fn ->
        {:ok, pid} = ChainServer.start_link(%{
          run_id: run_id,
          chat_model: chat_model,
          tools: tools,
          callback: callback})
        resp = ChainServer.execute(run_id, messages)
        Process.exit(pid, :kill)
        resp
      end)
    Task.await(task, :infinity)
  end

  # Public API to start the LLM request with a `run_id`. The reply arrives via
  # GenServer.reply/2 only when the chain finishes or is cancelled, so pass an
  # :infinity timeout instead of relying on GenServer.call's 5-second default.
  def execute(run_id, messages) do
    GenServer.call(get_pid(run_id), {:run_llm, messages}, :infinity)
  end

  # Public API to request cancellation with a `run_id`
  def cancel(run_id) do
    GenServer.cast(get_pid(run_id), :cancel)
  end

  @impl true
  def handle_call({:run_llm, messages}, caller, %{chat_model: chat_model, tools: tools} = state) do
    parent = self()
    Task.async(fn ->
      langchain_callback = fn
        %MessageDelta{} = data ->
          GenServer.cast(parent, {:receive_message_delta, data})

        %Message{} = _data ->
          # GenServer.cast(parent, {:receive_message_full, data})
          # we will handle this part with `result` below
          :ok
      end

      result =
        try do
          run_llm_chain(
            build_llm_chain(chat_model, messages, tools),
            langchain_callback
          )
        catch
          kind, reason ->
            {:error, {kind, reason}}
        end

      # Notify the GenServer
      GenServer.cast(parent, {:llm_request_result, result})
      :ok
    end)

    {:noreply, %{state | caller: caller}}
  end

  @impl true
  def handle_cast({:receive_message_delta, message_delta}, %{callback: callback, is_canceled: is_canceled, seq_num: seq_num} = state) do
    if not is_canceled do
      callback.({:message_delta, message_delta, seq_num})
      state =
        state
        |> Map.put(:seq_num, seq_num + 1)
        |> Map.put(:deltas, [message_delta.content | state.deltas])
      {:noreply, state}
    else
      # when is_canceled==true we don't want to change the state
      {:noreply, state}
    end
  end

  @impl true
  def handle_cast(:cancel, state) do
    GenServer.cast(self(), :cancel_collect_results)
    # set is_canceled=true to prevent future messages to be handled
    {:noreply, %{state | is_canceled: true}}
  end

  def handle_cast(:cancel_collect_results, %{is_canceled: true, deltas: deltas} = state) do
    # Reverse the list to get it in the original order before joining
    message_content =
      deltas
      |> Enum.reverse()
      |> IO.iodata_to_binary()

    message = %Message{
      content: message_content,
      index: 0,
      status: :canceled,
      role: :assistant
    }

    GenServer.cast(self(), {:llm_request_result, {:ok, nil, message}})
    {:noreply, state}
  end

  @impl true
  def handle_cast({:llm_request_result, result}, %{caller: caller, callback: callback} = state) do
    # send results to callback
    case result do
      {:ok, _, message} ->
        callback.({:message_full, message})
        # reply with results to ChainServer.run
        GenServer.reply(caller, {:ok, message})
      {:error, error} ->
        callback.({:error, error})
        GenServer.reply(caller, {:error, error})
    end

    {:noreply, state}
  end

  defp build_llm_chain(chat_model, messages, tools) do
    %{llm: chat_model, verbose: false}
    |> LLMChain.new!()
    |> LLMChain.add_messages(messages)
    |> LLMChain.add_tools(tools)
  end

  defp run_llm_chain(chain, callback_fn) do
    IO.puts("run_llm_chain")
    chain
    |> LLMChain.run(callback_fn: callback_fn)
  end

  defp get_pid(run_id) do
    {:via, Registry, {LLMChainRegistry, run_id}}
  end
end

defmodule LLMChainTest do
  alias Ecto.UUID
  def run_test do
    # Create a new ChatOpenAI model instance
    chat_model =
      LangChain.ChatModels.ChatOpenAI.new!(%{
        api_key: "XXX",
        stream: true
      })

    # Define the messages for the LLM
    messages = [
      LangChain.Message.new_user!("count from 1 to 20")
    ]

    # No functions are used in this example
    tools = []

    # Define the callback function to handle streaming responses
    callback = fn
      {:message_delta, %LangChain.MessageDelta{} = data, seq_num} ->
        # Handle streaming data
        IO.inspect(data, label: "CB:DELTA (#{seq_num}): ")

      {:message_full, %LangChain.Message{} = data} ->
        # Handle the completed message
        IO.puts("\n")
        IO.inspect(data.content, label: "CB:COMPLETED (status: #{data.status}) -> content")


      {:error, error} ->
        IO.inspect(error, label: "CB:ERROR")
    end

    # Start the GenServer and initialize the LLM request
    run_id = UUID.generate()

    # Set up a cancellation task that will trigger after 1.5 seconds
    Task.start(fn ->
      Process.sleep(1500)
      ChainServer.cancel(run_id)
    end)

    result = ChainServer.execute_once(
      run_id, chat_model, tools, callback, messages
    )

    IO.inspect(result, label: "final result")
  end
end

# You need to start the Registry before running the test
{:ok, _} = Registry.start_link(keys: :unique, name: LLMChainRegistry)

LLMChainTest.run_test()

Output:

CB:DELTA (0): : %LangChain.MessageDelta{
  content: "",
  status: :incomplete,
  index: 0,
  role: :assistant,
  tool_calls: nil
}

...

CB:DELTA (43): : %LangChain.MessageDelta{
  content: "15",
  status: :incomplete,
  index: 0,
  role: :unknown,
  tool_calls: nil
}
CB:DELTA (44): : %LangChain.MessageDelta{
  content: ",",
  status: :incomplete,
  index: 0,
  role: :unknown,
  tool_calls: nil
}
CB:DELTA (45): : %LangChain.MessageDelta{
  content: " ",
  status: :incomplete,
  index: 0,
  role: :unknown,
  tool_calls: nil
}


CB:COMPLETED (status: canceled) -> content: "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, "
final result: {:ok,
 %LangChain.Message{
   content: "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, ",
   index: 0,
   status: :canceled,
   role: :assistant,
   name: nil,
   tool_calls: nil,
   tool_results: nil
 }}

@brainlid
Owner

brainlid commented Jul 26, 2024

The idea of two separate processes is how it works. The demo project uses a Task-created process for running the chain, and the other process is the LiveView.

Yes, the other process could be a more explicitly managed GenServer.

Thanks for sharing your work!
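
For anyone turning the ChainServer above into a longer-lived, explicitly managed process, here is a sketch of one way it could be supervised per run. It assumes the ChainServer module and LLMChainRegistry from the previous comment; the ChainSupervisor name and the run_id/chat_model/callback bindings are placeholders.

# Start the Registry and a DynamicSupervisor once, e.g. in an application's
# supervision tree.
children = [
  {Registry, keys: :unique, name: LLMChainRegistry},
  {DynamicSupervisor, name: ChainSupervisor, strategy: :one_for_one}
]

{:ok, _sup} = Supervisor.start_link(children, strategy: :one_for_one)

# Then start one ChainServer per run/conversation and let the supervisor
# clean it up when the run is over.
{:ok, _pid} =
  DynamicSupervisor.start_child(
    ChainSupervisor,
    Supervisor.child_spec(
      {ChainServer, %{run_id: run_id, chat_model: chat_model, tools: [], callback: callback}},
      # per-run servers should not be restarted once they exit
      restart: :temporary
    )
  )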
