Refactor UI state management (#1191)
* Added logs at generation of the UI, and generate the UI in an object
* Make ingest script more verbose in case of an error at ingestion time
* Removed the explicit state in the UI containing ingested files
* Make script of ingestion a bit more verbose by displaying stack traces
* Change the browser tab title of privateGPT ui to `My Private GPT`
lopagela authored Nov 10, 2023
1 parent 55e626e commit a666fd5
Showing 3 changed files with 178 additions and 145 deletions.
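At a high level, the refactor replaces module-level state in the UI (the injected services, the Gradio blocks, and the ingested-file list) with a PrivateGptUi class that builds the UI lazily and caches it. A condensed sketch of that caching pattern, using the names from the diff below but with the construction details elided:

    class PrivateGptUi:
        def __init__(self) -> None:
            # Nothing is built at import time; the blocks are cached once built.
            self._ui_block = None

        def get_ui_blocks(self):
            # Construct the Gradio blocks on first use, then reuse the cached copy.
            if self._ui_block is None:
                self._ui_block = self._build_ui_blocks()  # defined in the diff below
            return self._ui_block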
8 changes: 6 additions & 2 deletions private_gpt/main.py
@@ -1,4 +1,5 @@
 """FastAPI app creation, logger configuration and main API routes."""
+import logging
 from typing import Any
 
 import llama_index
@@ -14,6 +15,8 @@
 from private_gpt.server.ingest.ingest_router import ingest_router
 from private_gpt.settings.settings import settings
 
+logger = logging.getLogger(__name__)
+
 # Add LlamaIndex simple observability
 llama_index.set_global_handler("simple")
@@ -103,6 +106,7 @@ def custom_openapi() -> dict[str, Any]:
 
 
 if settings.ui.enabled:
-    from private_gpt.ui.ui import mount_in_app
+    logger.debug("Importing the UI module")
+    from private_gpt.ui.ui import PrivateGptUi
 
-    mount_in_app(app)
+    PrivateGptUi().mount_in_app(app)
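For context on the mounting call: gr.mount_gradio_app attaches a Gradio Blocks app to an existing FastAPI application under a sub-path, which is what PrivateGptUi.mount_in_app wraps. A minimal, self-contained sketch (the /ui path and the demo content are illustrative, not taken from the repository):

    import gradio as gr
    from fastapi import FastAPI

    app = FastAPI()

    with gr.Blocks() as blocks:
        gr.Markdown("Hello from the mounted UI")

    blocks.queue()  # enable the event queue; required for generator (streaming) handlers
    gr.mount_gradio_app(app, blocks, path="/ui")

    # Serve with, e.g.: uvicorn my_module:app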
309 changes: 168 additions & 141 deletions private_gpt/ui/ui.py
@@ -1,4 +1,6 @@
 """This file should be imported only and only if you want to run the UI locally."""
 import itertools
+import logging
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, TextIO
@@ -15,151 +17,176 @@
 from private_gpt.settings.settings import settings
 from private_gpt.ui.images import logo_svg
 
-ingest_service = root_injector.get(IngestService)
-chat_service = root_injector.get(ChatService)
-chunks_service = root_injector.get(ChunksService)
-
-
-def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-    def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
-        full_response: str = ""
-        for delta in stream:
-            if isinstance(delta, str):
-                full_response += str(delta)
-            elif isinstance(delta, ChatResponse):
-                full_response += delta.delta or ""
-            yield full_response
-
-    def build_history() -> list[ChatMessage]:
-        history_messages: list[ChatMessage] = list(
-            itertools.chain(
-                *[
-                    [
-                        ChatMessage(content=interaction[0], role=MessageRole.USER),
-                        ChatMessage(content=interaction[1], role=MessageRole.ASSISTANT),
-                    ]
-                    for interaction in history
-                ]
-            )
-        )
-
-        # max 20 messages to try to avoid context overflow
-        return history_messages[:20]
-
-    new_message = ChatMessage(content=message, role=MessageRole.USER)
-    all_messages = [*build_history(), new_message]
-    match mode:
-        case "Query Docs":
-            query_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=True,
-            )
-            yield from yield_deltas(query_stream)
-
-        case "LLM Chat":
-            llm_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=False,
-            )
-            yield from yield_deltas(llm_stream)
-
-        case "Search in Docs":
-            response = chunks_service.retrieve_relevant(
-                text=message, limit=4, prev_next_chunks=0
-            )
-
-            yield "\n\n\n".join(
-                f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
-                f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
-                f"{chunk.text}"
-                for index, chunk in enumerate(response, start=1)
-            )
-
-
-def _list_ingested_files() -> list[str]:
-    files = set()
-    for ingested_document in ingest_service.list_ingested():
-        if ingested_document.doc_metadata is not None:
-            files.add(
-                ingested_document.doc_metadata.get("file_name") or "[FILE NAME MISSING]"
-            )
-    return list(files)
-
-
-# Global state
-_uploaded_file_list = [[row] for row in _list_ingested_files()]
-
-
-def _upload_file(file: TextIO) -> list[list[str]]:
-    path = Path(file.name)
-    ingest_service.ingest(file_name=path.name, file_data=path)
-    _uploaded_file_list.append([path.name])
-    return _uploaded_file_list
-
-
-with gr.Blocks(
-    theme=gr.themes.Soft(primary_hue=slate),
-    css=".logo { "
-    "display:flex;"
-    "background-color: #C7BAFF;"
-    "height: 80px;"
-    "border-radius: 8px;"
-    "align-content: center;"
-    "justify-content: center;"
-    "align-items: center;"
-    "}"
-    ".logo img { height: 25% }",
-) as blocks:
-    with gr.Row():
-        gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
-
-    with gr.Row():
-        with gr.Column(scale=3, variant="compact"):
-            mode = gr.Radio(
-                ["Query Docs", "Search in Docs", "LLM Chat"],
-                label="Mode",
-                value="Query Docs",
-            )
-            upload_button = gr.components.UploadButton(
-                "Upload a File",
-                type="file",
-                file_count="single",
-                size="sm",
-            )
-            ingested_dataset = gr.List(
-                _uploaded_file_list,
-                headers=["File name"],
-                label="Ingested Files",
-                interactive=False,
-                render=False,  # Rendered under the button
-            )
-            upload_button.upload(
-                _upload_file, inputs=upload_button, outputs=ingested_dataset
-            )
-            ingested_dataset.render()
-        with gr.Column(scale=7):
-            chatbot = gr.ChatInterface(
-                _chat,
-                chatbot=gr.Chatbot(
-                    label=f"LLM: {settings.llm.mode}",
-                    show_copy_button=True,
-                    render=False,
-                    avatar_images=(
-                        None,
-                        "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
-                        "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
-                        "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
-                    ),
-                ),
-                additional_inputs=[mode, upload_button],
-            )
-
-
-def mount_in_app(app: FastAPI) -> None:
-    blocks.queue()
-    gr.mount_gradio_app(app, blocks, path=settings.ui.path)
+logger = logging.getLogger(__name__)
+
+
+UI_TAB_TITLE = "My Private GPT"
+
+
+class PrivateGptUi:
+    def __init__(self) -> None:
+        self._ingest_service = root_injector.get(IngestService)
+        self._chat_service = root_injector.get(ChatService)
+        self._chunks_service = root_injector.get(ChunksService)
+
+        # Cache the UI blocks
+        self._ui_block = None
+
+    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
+        def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
+            full_response: str = ""
+            for delta in stream:
+                if isinstance(delta, str):
+                    full_response += str(delta)
+                elif isinstance(delta, ChatResponse):
+                    full_response += delta.delta or ""
+                yield full_response
+
+        def build_history() -> list[ChatMessage]:
+            history_messages: list[ChatMessage] = list(
+                itertools.chain(
+                    *[
+                        [
+                            ChatMessage(content=interaction[0], role=MessageRole.USER),
+                            ChatMessage(
+                                content=interaction[1], role=MessageRole.ASSISTANT
+                            ),
+                        ]
+                        for interaction in history
+                    ]
+                )
+            )
+
+            # max 20 messages to try to avoid context overflow
+            return history_messages[:20]
+
+        new_message = ChatMessage(content=message, role=MessageRole.USER)
+        all_messages = [*build_history(), new_message]
+        match mode:
+            case "Query Docs":
+                query_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=True,
+                )
+                yield from yield_deltas(query_stream)
+
+            case "LLM Chat":
+                llm_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=False,
+                )
+                yield from yield_deltas(llm_stream)
+
+            case "Search in Docs":
+                response = self._chunks_service.retrieve_relevant(
+                    text=message, limit=4, prev_next_chunks=0
+                )
+
+                yield "\n\n\n".join(
+                    f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
+                    f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
+                    f"{chunk.text}"
+                    for index, chunk in enumerate(response, start=1)
+                )
+
+    def _list_ingested_files(self) -> list[list[str]]:
+        files = set()
+        for ingested_document in self._ingest_service.list_ingested():
+            if ingested_document.doc_metadata is None:
+                # Skipping documents without metadata
+                continue
+            file_name = ingested_document.doc_metadata.get(
+                "file_name", "[FILE NAME MISSING]"
+            )
+            files.add(file_name)
+        return [[row] for row in files]
+
+    def _upload_file(self, file: TextIO) -> None:
+        path = Path(file.name)
+        self._ingest_service.ingest(file_name=path.name, file_data=path)
+
+    def _build_ui_blocks(self) -> gr.Blocks:
+        logger.debug("Creating the UI blocks")
+        with gr.Blocks(
+            title=UI_TAB_TITLE,
+            theme=gr.themes.Soft(primary_hue=slate),
+            css=".logo { "
+            "display:flex;"
+            "background-color: #C7BAFF;"
+            "height: 80px;"
+            "border-radius: 8px;"
+            "align-content: center;"
+            "justify-content: center;"
+            "align-items: center;"
+            "}"
+            ".logo img { height: 25% }",
+        ) as blocks:
+            with gr.Row():
+                gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
+
+            with gr.Row():
+                with gr.Column(scale=3, variant="compact"):
+                    mode = gr.Radio(
+                        ["Query Docs", "Search in Docs", "LLM Chat"],
+                        label="Mode",
+                        value="Query Docs",
+                    )
+                    upload_button = gr.components.UploadButton(
+                        "Upload a File",
+                        type="file",
+                        file_count="single",
+                        size="sm",
+                    )
+                    ingested_dataset = gr.List(
+                        self._list_ingested_files,
+                        headers=["File name"],
+                        label="Ingested Files",
+                        interactive=False,
+                        render=False,  # Rendered under the button
+                    )
+                    upload_button.upload(
+                        self._upload_file,
+                        inputs=upload_button,
+                        outputs=ingested_dataset,
+                    )
+                    ingested_dataset.change(
+                        self._list_ingested_files,
+                        outputs=ingested_dataset,
+                    )
+                    ingested_dataset.render()
+                with gr.Column(scale=7):
+                    _ = gr.ChatInterface(
+                        self._chat,
+                        chatbot=gr.Chatbot(
+                            label=f"LLM: {settings.llm.mode}",
+                            show_copy_button=True,
+                            render=False,
+                            avatar_images=(
+                                None,
+                                "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
+                                "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
+                                "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
+                            ),
+                        ),
+                        additional_inputs=[mode, upload_button],
+                    )
+        return blocks
+
+    def get_ui_blocks(self) -> gr.Blocks:
+        if self._ui_block is None:
+            self._ui_block = self._build_ui_blocks()
+        return self._ui_block
+
+    def mount_in_app(self, app: FastAPI) -> None:
+        blocks = self.get_ui_blocks()
+        blocks.queue()
+        base_path = settings.ui.path
+        logger.info("Mounting the gradio UI, at path=%s", base_path)
+        gr.mount_gradio_app(app, blocks, path=base_path)
 
 
 if __name__ == "__main__":
-    blocks.queue()
-    blocks.launch(debug=False, show_api=False)
+    ui = PrivateGptUi()
+    _blocks = ui.get_ui_blocks()
+    _blocks.queue()
+    _blocks.launch(debug=False, show_api=False)
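Two details of the new ui.py are worth calling out. First, yield_deltas yields the accumulated response rather than each raw delta, because Gradio's ChatInterface replaces the displayed bot message with every value the generator yields instead of appending to it. A minimal sketch of the accumulation pattern, with a hardcoded delta stream standing in for the chat service:

    from collections.abc import Iterable

    def yield_deltas(stream: Iterable[str]) -> Iterable[str]:
        full_response = ""
        for delta in stream:
            full_response += delta
            yield full_response  # each yield replaces the message shown so far

    for partial in yield_deltas(["Hel", "lo", " world"]):
        print(partial)
    # Hel
    # Hello
    # Hello world

Second, the explicit _uploaded_file_list global is gone: gr.List now receives the self._list_ingested_files callable as its value, so Gradio can re-evaluate it against the ingest service, and the upload and change events refresh it afterwards. The ingest service becomes the single source of truth for the file list.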
6 changes: 4 additions & 2 deletions scripts/ingest_folder.py
@@ -69,8 +69,10 @@ def _do_ingest(changed_path: Path) -> None:
         logger.info(f"Started ingesting {changed_path}")
         ingest_service.ingest(changed_path.name, changed_path)
         logger.info(f"Completed ingesting {changed_path}")
-    except Exception as e:
-        logger.error(f"Failed to ingest document: {changed_path}. Error: {e}")
+    except Exception:
+        logger.exception(
+            f"Failed to ingest document: {changed_path}, find the exception attached"
+        )
 
 
 path = Path(args.folder)
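The ingest script change swaps logger.error for logging's Logger.exception, which logs at ERROR level and automatically appends the stack trace of the exception currently being handled; the old call recorded only the formatted message. A minimal sketch:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    try:
        raise ValueError("boom")
    except Exception:
        # Must run inside an except block: the in-flight exception's
        # traceback is appended to the log record automatically.
        logger.exception("Failed to ingest document")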
