Allow passing a system prompt (#1318)
imartinez authored Nov 29, 2023
1 parent 9c192dd commit 64ed9cd
Showing 6 changed files with 1,037 additions and 947 deletions.
1,778 changes: 883 additions & 895 deletions fern/openapi/openapi.json

Large diffs are not rendered by default.

19 changes: 16 additions & 3 deletions private_gpt/server/chat/chat_router.py
@@ -28,10 +28,14 @@ class ChatBody(BaseModel):
"examples": [
{
"messages": [
{
"role": "system",
"content": "You are a rapper. Always answer with a rap.",
},
{
"role": "user",
"content": "How do you fry an egg?",
}
},
],
"stream": False,
"use_context": True,
@@ -56,6 +60,9 @@ def chat_completion(
) -> OpenAICompletion | StreamingResponse:
"""Given a list of messages comprising a conversation, return a response.
Optionally include an initial `role: system` message to influence the way
the LLM answers.
If `use_context` is set to `true`, the model will use context coming
from the ingested documents to create the response. The documents being used can
be filtered using the `context_filter` and passing the document IDs to be used.
@@ -79,7 +86,9 @@ def chat_completion(
]
if body.stream:
completion_gen = service.stream_chat(
all_messages, body.use_context, body.context_filter
messages=all_messages,
use_context=body.use_context,
context_filter=body.context_filter,
)
return StreamingResponse(
to_openai_sse_stream(
@@ -89,7 +98,11 @@
media_type="text/event-stream",
)
else:
completion = service.chat(all_messages, body.use_context, body.context_filter)
completion = service.chat(
messages=all_messages,
use_context=body.use_context,
context_filter=body.context_filter,
)
return to_openai_response(
completion.response, completion.sources if body.include_sources else None
)
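To illustrate the new request shape this router accepts, here is a minimal client-side sketch. The host, port, and exact route are assumptions for a default local deployment and may need adjusting; the response is read assuming the OpenAI-compatible layout this API mimics.

```python
# Hedged sketch: assumes a locally running PrivateGPT server; adjust the URL to your setup.
import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a rapper. Always answer with a rap."},
        {"role": "user", "content": "How do you fry an egg?"},
    ],
    "stream": False,
    "use_context": False,
    "include_sources": False,
}

response = requests.post("http://localhost:8001/v1/chat/completions", json=payload)
response.raise_for_status()
# Assuming the OpenAI-style response layout, the generated text lives here:
print(response.json()["choices"][0]["message"]["content"])
```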
159 changes: 113 additions & 46 deletions private_gpt/server/chat/chat_service.py
@@ -1,12 +1,13 @@
from dataclasses import dataclass

from injector import inject, singleton
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.chat_engine import ContextChatEngine
from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
from llama_index.chat_engine.types import (
BaseChatEngine,
)
from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
from llama_index.llms import ChatMessage
from llama_index.llms import ChatMessage, MessageRole
from llama_index.types import TokenGen
from pydantic import BaseModel

@@ -30,6 +31,40 @@ class CompletionGen(BaseModel):
sources: list[Chunk] | None = None


@dataclass
class ChatEngineInput:
system_message: ChatMessage | None = None
last_message: ChatMessage | None = None
chat_history: list[ChatMessage] | None = None

@classmethod
def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
# Detect if there is a system message, extract the last message and chat history
system_message = (
messages[0]
if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
else None
)
last_message = (
messages[-1]
if len(messages) > 0 and messages[-1].role == MessageRole.USER
else None
)
# Remove from messages list the system message and last message,
# if they exist. The rest is the chat history.
if system_message:
messages.pop(0)
if last_message:
messages.pop(-1)
chat_history = messages if len(messages) > 0 else None

return cls(
system_message=system_message,
last_message=last_message,
chat_history=chat_history,
)


@singleton
class ChatService:
@inject
@@ -58,43 +93,63 @@ def __init__(
)

def _chat_engine(
self, context_filter: ContextFilter | None = None
) -> BaseChatEngine:
vector_index_retriever = self.vector_store_component.get_retriever(
index=self.index, context_filter=context_filter
)
return ContextChatEngine.from_defaults(
retriever=vector_index_retriever,
service_context=self.service_context,
node_postprocessors=[
MetadataReplacementPostProcessor(target_metadata_key="window"),
],
)

def stream_chat(
self,
messages: list[ChatMessage],
system_prompt: str | None = None,
use_context: bool = False,
context_filter: ContextFilter | None = None,
) -> CompletionGen:
) -> BaseChatEngine:
if use_context:
last_message = messages[-1].content
chat_engine = self._chat_engine(context_filter=context_filter)
streaming_response = chat_engine.stream_chat(
message=last_message if last_message is not None else "",
chat_history=messages[:-1],
vector_index_retriever = self.vector_store_component.get_retriever(
index=self.index, context_filter=context_filter
)
sources = [
Chunk.from_node(node) for node in streaming_response.source_nodes
]
completion_gen = CompletionGen(
response=streaming_response.response_gen, sources=sources
return ContextChatEngine.from_defaults(
system_prompt=system_prompt,
retriever=vector_index_retriever,
service_context=self.service_context,
node_postprocessors=[
MetadataReplacementPostProcessor(target_metadata_key="window"),
],
)
else:
stream = self.llm_service.llm.stream_chat(messages)
completion_gen = CompletionGen(
response=stream_chat_response_to_tokens(stream)
return SimpleChatEngine.from_defaults(
system_prompt=system_prompt,
service_context=self.service_context,
)

def stream_chat(
self,
messages: list[ChatMessage],
use_context: bool = False,
context_filter: ContextFilter | None = None,
) -> CompletionGen:
chat_engine_input = ChatEngineInput.from_messages(messages)
last_message = (
chat_engine_input.last_message.content
if chat_engine_input.last_message
else None
)
system_prompt = (
chat_engine_input.system_message.content
if chat_engine_input.system_message
else None
)
chat_history = (
chat_engine_input.chat_history if chat_engine_input.chat_history else None
)

chat_engine = self._chat_engine(
system_prompt=system_prompt,
use_context=use_context,
context_filter=context_filter,
)
streaming_response = chat_engine.stream_chat(
message=last_message if last_message is not None else "",
chat_history=chat_history,
)
sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
completion_gen = CompletionGen(
response=streaming_response.response_gen, sources=sources
)
return completion_gen

def chat(
@@ -103,18 +158,30 @@ def chat(
use_context: bool = False,
context_filter: ContextFilter | None = None,
) -> Completion:
if use_context:
last_message = messages[-1].content
chat_engine = self._chat_engine(context_filter=context_filter)
wrapped_response = chat_engine.chat(
message=last_message if last_message is not None else "",
chat_history=messages[:-1],
)
sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
completion = Completion(response=wrapped_response.response, sources=sources)
else:
chat_response = self.llm_service.llm.chat(messages)
response_content = chat_response.message.content
response = response_content if response_content is not None else ""
completion = Completion(response=response)
chat_engine_input = ChatEngineInput.from_messages(messages)
last_message = (
chat_engine_input.last_message.content
if chat_engine_input.last_message
else None
)
system_prompt = (
chat_engine_input.system_message.content
if chat_engine_input.system_message
else None
)
chat_history = (
chat_engine_input.chat_history if chat_engine_input.chat_history else None
)

chat_engine = self._chat_engine(
system_prompt=system_prompt,
use_context=use_context,
context_filter=context_filter,
)
wrapped_response = chat_engine.chat(
message=last_message if last_message is not None else "",
chat_history=chat_history,
)
sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
completion = Completion(response=wrapped_response.response, sources=sources)
return completion
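To make the message-splitting behaviour of the new ChatEngineInput.from_messages helper concrete, a small usage sketch follows, assuming the class is importable from the module shown above. Note that the helper pops the system and last user messages off the list it is given.

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.server.chat.chat_service import ChatEngineInput

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a rapper. Always answer with a rap."),
    ChatMessage(role=MessageRole.USER, content="What is PrivateGPT?"),
    ChatMessage(role=MessageRole.ASSISTANT, content="It keeps your documents local."),
    ChatMessage(role=MessageRole.USER, content="How do you fry an egg?"),
]

chat_engine_input = ChatEngineInput.from_messages(messages)
# chat_engine_input.system_message -> the leading system message
# chat_engine_input.last_message   -> the trailing user message
# chat_engine_input.chat_history   -> the remaining messages in between
# (from_messages mutates the list passed in: the system and last user messages are popped.)
```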
16 changes: 13 additions & 3 deletions private_gpt/server/completions/completions_router.py
@@ -15,6 +15,7 @@

class CompletionsBody(BaseModel):
prompt: str
system_prompt: str | None = None
use_context: bool = False
context_filter: ContextFilter | None = None
include_sources: bool = True
@@ -25,6 +26,7 @@ class CompletionsBody(BaseModel):
"examples": [
{
"prompt": "How do you fry an egg?",
"system_prompt": "You are a rapper. Always answer with a rap.",
"stream": False,
"use_context": False,
"include_sources": False,
@@ -46,7 +48,11 @@ def prompt_completion(
) -> OpenAICompletion | StreamingResponse:
"""We recommend most users use our Chat completions API.
Given a prompt, the model will return one predicted completion. If `use_context`
Given a prompt, the model will return one predicted completion.
Optionally include a `system_prompt` to influence the way the LLM answers.
If `use_context`
is set to `true`, the model will use context coming from the ingested documents
to create the response. The documents being used can be filtered using the
`context_filter` and passing the document IDs to be used. Ingested documents IDs
@@ -64,9 +70,13 @@ def prompt_completion(
"finish_reason":null}]}
```
"""
message = OpenAIMessage(content=body.prompt, role="user")
messages = [OpenAIMessage(content=body.prompt, role="user")]
# If system prompt is passed, create a fake message with the system prompt.
if body.system_prompt:
messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))

chat_body = ChatBody(
messages=[message],
messages=messages,
use_context=body.use_context,
stream=body.stream,
include_sources=body.include_sources,
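A corresponding sketch of the prompt-style request using the new optional system_prompt field; host, port, and route are again assumptions for a default local setup.

```python
# Hedged sketch of the completions variant with the new system_prompt field.
import requests

payload = {
    "prompt": "How do you fry an egg?",
    "system_prompt": "You are a rapper. Always answer with a rap.",
    "stream": False,
    "use_context": False,
    "include_sources": False,
}

response = requests.post("http://localhost:8001/v1/completions", json=payload)
response.raise_for_status()
print(response.json())  # OpenAI-style completion object
```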
11 changes: 11 additions & 0 deletions private_gpt/ui/ui.py
@@ -116,6 +116,17 @@ def build_history() -> list[ChatMessage]:
all_messages = [*build_history(), new_message]
match mode:
case "Query Docs":
# Add a system message to force the behaviour of the LLM
# to answer only questions about the provided context.
all_messages.insert(
0,
ChatMessage(
content="You can only answer questions about the provided context. If you know the answer "
"but it is not based in the provided context, don't provide the answer, just state "
"the answer is not in the context provided.",
role=MessageRole.SYSTEM,
),
)
query_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=True,
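For completeness, a sketch of driving ChatService directly the way the UI now does in "Query Docs" mode: prepend a system message, then consume the streamed tokens. How the service instance is obtained is a placeholder here; in the application it is provided by dependency injection.

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.server.chat.chat_service import ChatService

chat_service: ChatService = ...  # hypothetical: replace with your injected ChatService instance

all_messages = [
    ChatMessage(
        role=MessageRole.SYSTEM,
        content="You can only answer questions about the provided context.",
    ),
    ChatMessage(role=MessageRole.USER, content="What does the ingested report conclude?"),
]

completion_gen = chat_service.stream_chat(messages=all_messages, use_context=True)
for token in completion_gen.response:  # CompletionGen.response is a generator of text deltas
    print(token, end="", flush=True)
```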
1 change: 1 addition & 0 deletions settings.yaml
@@ -22,6 +22,7 @@ ui:

llm:
mode: local

embedding:
# Should be matching the value above in most cases
mode: local
