Skip to content

Commit

Permalink
support chat history (Quansight#438)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <[email protected]>
  • Loading branch information
blakerosenthal and pmeier authored Jul 11, 2024
1 parent 6bdda75 commit 55f7fc5
Show file tree
Hide file tree
Showing 16 changed files with 74 additions and 58 deletions.
4 changes: 2 additions & 2 deletions docs/examples/gallery_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@


class DemoStreamingAssistant(assistants.RagnaDemoAssistant):
def answer(self, prompt, sources):
content = next(super().answer(prompt, sources))
def answer(self, messages):
content = next(super().answer(messages))
for chunk in content.split(" "):
yield f"{chunk} "

Expand Down
20 changes: 10 additions & 10 deletions docs/tutorials/gallery_custom_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import uuid

from ragna.core import Document, Source, SourceStorage
from ragna.core import Document, Source, SourceStorage, Message


class TutorialSourceStorage(SourceStorage):
Expand Down Expand Up @@ -61,9 +61,9 @@ def retrieve(
# %%
# ### Assistant
#
# [ragna.core.Assistant][]s are objects that take a user prompt and relevant
# [ragna.core.Source][]s and generate a response form that. Usually, assistants are
# LLMs.
# [ragna.core.Assistant][]s are objects that take the chat history as list of
# [ragna.core.Message][]s and their relevant [ragna.core.Source][]s and generate a
# response from that. Usually, assistants are LLMs.
#
# In this tutorial, we define a minimal `TutorialAssistant` that is similar to
# [ragna.assistants.RagnaDemoAssistant][]. In `.answer()` we mirror back the user
Expand All @@ -82,8 +82,11 @@ def retrieve(


class TutorialAssistant(Assistant):
def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
def answer(self, messages: list[Message]) -> Iterator[str]:
print(f"Running {type(self).__name__}().answer()")
# For simplicity, we only deal with the last message here, i.e. the latest user
# prompt.
prompt, sources = (message := messages[-1]).content, message.sources
yield (
f"To answer the user prompt '{prompt}', "
f"I was given {len(sources)} source(s)."
Expand Down Expand Up @@ -254,8 +257,7 @@ def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
class ElaborateTutorialAssistant(Assistant):
def answer(
self,
prompt: str,
sources: list[Source],
messages: list[Message],
*,
my_required_parameter: int,
my_optional_parameter: str = "foo",
Expand Down Expand Up @@ -393,9 +395,7 @@ def answer(


class AsyncAssistant(Assistant):
async def answer(
self, prompt: str, sources: list[Source]
) -> AsyncIterator[str]:
async def answer(self, messages: list[Message]) -> AsyncIterator[str]:
print(f"Running {type(self).__name__}().answer()")
start = time.perf_counter()
await asyncio.sleep(0.3)
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast

from ragna.core import Source
from ragna.core import Message, Source

from ._http_api import HttpApiAssistant

Expand All @@ -23,11 +23,12 @@ def _make_system_content(self, sources: list[Source]) -> str:
return instruction + "\n\n".join(source.content for source in sources)

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast

from ragna.core import PackageRequirement, RagnaException, Requirement, Source
from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand Down Expand Up @@ -37,10 +37,11 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
)

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source
from ragna.core import Message, RagnaException, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand All @@ -25,11 +25,12 @@ def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
return [{"title": source.id, "snippet": source.content} for source in sources]

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
prompt, sources = (message := messages[-1]).content, message.sources
async for event in self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
Expand Down
26 changes: 17 additions & 9 deletions ragna/assistants/_demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import re
import textwrap
from typing import Iterator

from ragna.core import Assistant, Source
from ragna.core import Assistant, Message, MessageRole


class RagnaDemoAssistant(Assistant):
Expand All @@ -22,11 +21,11 @@ class RagnaDemoAssistant(Assistant):
def display_name(cls) -> str:
return "Ragna/DemoAssistant"

def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
if re.search("markdown", prompt, re.IGNORECASE):
def answer(self, messages: list[Message]) -> Iterator[str]:
if "markdown" in messages[-1].content.lower():
yield self._markdown_answer()
else:
yield self._default_answer(prompt, sources)
yield self._default_answer(messages)

def _markdown_answer(self) -> str:
return textwrap.dedent(
Expand All @@ -39,7 +38,8 @@ def _markdown_answer(self) -> str:
"""
).strip()

def _default_answer(self, prompt: str, sources: list[Source]) -> str:
def _default_answer(self, messages: list[Message]) -> str:
prompt, sources = (message := messages[-1]).content, message.sources
sources_display = []
for source in sources:
source_display = f"- {source.document.name}"
Expand All @@ -50,13 +50,16 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str:
if len(sources) > 3:
sources_display.append("[...]")

n_messages = len([m for m in messages if m.role == MessageRole.USER])
return (
textwrap.dedent(
"""
I'm a demo assistant and can be used to try Ragnas workflow.
I'm a demo assistant and can be used to try Ragna's workflow.
I will only mirror back my inputs.
So far I have received {n_messages} messages.
Your prompt was:
Your last prompt was:
> {prompt}
Expand All @@ -66,5 +69,10 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str:
"""
)
.strip()
.format(name=str(self), prompt=prompt, sources="\n".join(sources_display))
.format(
name=str(self),
n_messages=n_messages,
prompt=prompt,
sources="\n".join(sources_display),
)
)
5 changes: 3 additions & 2 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import AsyncIterator

from ragna.core import Source
from ragna.core import Message, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand All @@ -26,8 +26,9 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
)

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for chunk in self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source
from ragna.core import Message, RagnaException

from ._http_api import HttpStreamingProtocol
from ._openai import OpenaiLikeHttpApiAssistant
Expand Down Expand Up @@ -30,8 +30,9 @@ def _url(self) -> str:
return f"{base_url}/api/chat"

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
Expand Down
5 changes: 3 additions & 2 deletions ragna/assistants/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property
from typing import Any, AsyncIterator, Optional, cast

from ragna.core import Source
from ragna.core import Message, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

Expand Down Expand Up @@ -55,8 +55,9 @@ def _stream(
return self._call_api("POST", self._url, headers=headers, json=json_)

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
choice = data["choices"][0]
if choice["finish_reason"] is not None:
Expand Down
10 changes: 5 additions & 5 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
...


class MessageRole(enum.Enum):
class MessageRole(str, enum.Enum):
"""Message role
Attributes:
Expand Down Expand Up @@ -238,12 +238,12 @@ class Assistant(Component, abc.ABC):
__ragna_protocol_methods__ = ["answer"]

@abc.abstractmethod
def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
"""Answer a prompt given some sources.
def answer(self, messages: list[Message]) -> Iterator[str]:
"""Answer a prompt given the chat history.
Args:
prompt: Prompt to be answered.
sources: Sources to use when answering answer the prompt.
messages: List of messages in the chat history. The last item is the current
user prompt and has the relevant sources attached to it.
Returns:
Answer.
Expand Down
7 changes: 4 additions & 3 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,13 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
detail=RagnaException.EVENT,
)

self._messages.append(Message(content=prompt, role=MessageRole.USER))

sources = await self._run(self.source_storage.retrieve, self.documents, prompt)

question = Message(content=prompt, role=MessageRole.USER, sources=sources)
self._messages.append(question)

answer = Message(
content=self._run_gen(self.assistant.answer, prompt, sources),
content=self._run_gen(self.assistant.answer, self._messages.copy()),
role=MessageRole.ASSISTANT,
sources=sources,
)
Expand Down
14 changes: 7 additions & 7 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,15 @@ async def answer(
) -> schemas.Message:
with get_session() as session:
chat = database.get_chat(session, user=user, id=id)
chat.messages.append(
schemas.Message(content=prompt, role=ragna.core.MessageRole.USER)
)
core_chat = schema_to_core_chat(session, user=user, chat=chat)

core_answer = await core_chat.answer(prompt, stream=stream)
sources = [schemas.Source.from_core(source) for source in core_answer.sources]
chat.messages.append(
schemas.Message(
content=prompt, role=ragna.core.MessageRole.USER, sources=sources
)
)

if stream:

Expand All @@ -303,10 +306,7 @@ async def message_chunks() -> AsyncIterator[BaseModel]:
answer = schemas.Message(
content=content_chunk,
role=core_answer.role,
sources=[
schemas.Source.from_core(source)
for source in core_answer.sources
],
sources=sources,
)
yield answer

Expand Down
5 changes: 3 additions & 2 deletions tests/assistants/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ragna import assistants
from ragna._compat import anext
from ragna.assistants._http_api import HttpApiAssistant
from ragna.core import RagnaException
from ragna.core import Message, RagnaException
from tests.utils import skip_on_windows

HTTP_API_ASSISTANTS = [
Expand All @@ -25,7 +25,8 @@
async def test_api_call_error_smoke(mocker, assistant):
mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"})

chunks = assistant().answer(prompt="?", sources=[])
messages = [Message(content="?", sources=[])]
chunks = assistant().answer(messages)

with pytest.raises(RagnaException, match="API call failed"):
await anext(chunks)
6 changes: 2 additions & 4 deletions tests/core/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def test_params_validation_missing(self, demo_document):
class ValidationAssistant(Assistant):
def answer(
self,
prompt,
sources,
messages,
bool_param: bool,
int_param: int,
float_param: float,
Expand All @@ -65,8 +64,7 @@ def test_params_validation_wrong_type(self, demo_document):
class ValidationAssistant(Assistant):
def answer(
self,
prompt,
sources,
messages,
bool_param: bool,
int_param: int,
float_param: float,
Expand Down
6 changes: 4 additions & 2 deletions tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from ragna.deploy import Config
from ragna.deploy._api import app
from tests.deploy.utils import TestAssistant, authenticate_with_api
from tests.utils import skip_on_windows


@skip_on_windows
@pytest.mark.parametrize("multiple_answer_chunks", [True, False])
@pytest.mark.parametrize("stream_answer", [True, False])
def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
Expand Down Expand Up @@ -107,12 +109,12 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):

chat = client.get(f"/chats/{chat['id']}").raise_for_status().json()
assert len(chat["messages"]) == 3
assert chat["messages"][-1] == message
assert (
chat["messages"][-2]["role"] == "user"
and chat["messages"][-2]["sources"] == []
and chat["messages"][-2]["sources"] == message["sources"]
and chat["messages"][-2]["content"] == prompt
)
assert chat["messages"][-1] == message

client.delete(f"/chats/{chat['id']}").raise_for_status()
assert client.get("/chats").raise_for_status().json() == []
Loading

0 comments on commit 55f7fc5

Please sign in to comment.