From a666fd5b73965986ace8e234271a4777e6794767 Mon Sep 17 00:00:00 2001 From: lopagela Date: Fri, 10 Nov 2023 10:42:43 +0100 Subject: [PATCH] Refactor UI state management (#1191) * Added logs at generation of the UI, and generate the UI in an object * Make ingest script more verbose in case of an error at ingestion time * Removed the explicit state in the UI containing ingested files * Make script of ingestion a bit more verbose by displaying stack traces * Change the browser tab title of privateGPT ui to `My Private GPT` --- private_gpt/main.py | 8 +- private_gpt/ui/ui.py | 309 +++++++++++++++++++++------------------ scripts/ingest_folder.py | 6 +- 3 files changed, 178 insertions(+), 145 deletions(-) diff --git a/private_gpt/main.py b/private_gpt/main.py index f5f193e2f..2cf6bf90a 100644 --- a/private_gpt/main.py +++ b/private_gpt/main.py @@ -1,4 +1,5 @@ """FastAPI app creation, logger configuration and main API routes.""" +import logging from typing import Any import llama_index @@ -14,6 +15,8 @@ from private_gpt.server.ingest.ingest_router import ingest_router from private_gpt.settings.settings import settings +logger = logging.getLogger(__name__) + # Add LlamaIndex simple observability llama_index.set_global_handler("simple") @@ -103,6 +106,7 @@ def custom_openapi() -> dict[str, Any]: if settings.ui.enabled: - from private_gpt.ui.ui import mount_in_app + logger.debug("Importing the UI module") + from private_gpt.ui.ui import PrivateGptUi - mount_in_app(app) + PrivateGptUi().mount_in_app(app) diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py index b66223bd9..430dbbf6b 100644 --- a/private_gpt/ui/ui.py +++ b/private_gpt/ui/ui.py @@ -1,4 +1,6 @@ +"""This file should be imported only and only if you want to run the UI locally.""" import itertools +import logging from collections.abc import Iterable from pathlib import Path from typing import Any, TextIO @@ -15,151 +17,176 @@ from private_gpt.settings.settings import settings from private_gpt.ui.images import logo_svg -ingest_service = root_injector.get(IngestService) -chat_service = root_injector.get(ChatService) -chunks_service = root_injector.get(ChunksService) - - -def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any: - def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]: - full_response: str = "" - for delta in stream: - if isinstance(delta, str): - full_response += str(delta) - elif isinstance(delta, ChatResponse): - full_response += delta.delta or "" - yield full_response - - def build_history() -> list[ChatMessage]: - history_messages: list[ChatMessage] = list( - itertools.chain( - *[ - [ - ChatMessage(content=interaction[0], role=MessageRole.USER), - ChatMessage(content=interaction[1], role=MessageRole.ASSISTANT), +logger = logging.getLogger(__name__) + + +UI_TAB_TITLE = "My Private GPT" + + +class PrivateGptUi: + def __init__(self) -> None: + self._ingest_service = root_injector.get(IngestService) + self._chat_service = root_injector.get(ChatService) + self._chunks_service = root_injector.get(ChunksService) + + # Cache the UI blocks + self._ui_block = None + + def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any: + def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]: + full_response: str = "" + for delta in stream: + if isinstance(delta, str): + full_response += str(delta) + elif isinstance(delta, ChatResponse): + full_response += delta.delta or "" + yield full_response + + def build_history() -> list[ChatMessage]: + 
+            history_messages: list[ChatMessage] = list(
+                itertools.chain(
+                    *[
+                        [
+                            ChatMessage(content=interaction[0], role=MessageRole.USER),
+                            ChatMessage(
+                                content=interaction[1], role=MessageRole.ASSISTANT
+                            ),
+                        ]
+                        for interaction in history
                     ]
-                    for interaction in history
-                ]
+                )
             )
-        )
-
-        # max 20 messages to try to avoid context overflow
-        return history_messages[:20]
-
-    new_message = ChatMessage(content=message, role=MessageRole.USER)
-    all_messages = [*build_history(), new_message]
-    match mode:
-        case "Query Docs":
-            query_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=True,
-            )
-            yield from yield_deltas(query_stream)
-
-        case "LLM Chat":
-            llm_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=False,
-            )
-            yield from yield_deltas(llm_stream)
-
-        case "Search in Docs":
-            response = chunks_service.retrieve_relevant(
-                text=message, limit=4, prev_next_chunks=0
-            )
-
-            yield "\n\n\n".join(
-                f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
-                f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
-                f"{chunk.text}"
-                for index, chunk in enumerate(response, start=1)
-            )
-
-def _list_ingested_files() -> list[str]:
-    files = set()
-    for ingested_document in ingest_service.list_ingested():
-        if ingested_document.doc_metadata is not None:
-            files.add(
-                ingested_document.doc_metadata.get("file_name") or "[FILE NAME MISSING]"
+
+            # max 20 messages to try to avoid context overflow
+            return history_messages[:20]
+
+        new_message = ChatMessage(content=message, role=MessageRole.USER)
+        all_messages = [*build_history(), new_message]
+        match mode:
+            case "Query Docs":
+                query_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=True,
+                )
+                yield from yield_deltas(query_stream)
+
+            case "LLM Chat":
+                llm_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=False,
+                )
+                yield from yield_deltas(llm_stream)
+
+            case "Search in Docs":
+                response = self._chunks_service.retrieve_relevant(
+                    text=message, limit=4, prev_next_chunks=0
+                )
+
+                yield "\n\n\n".join(
+                    f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
+                    f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
+                    f"{chunk.text}"
+                    for index, chunk in enumerate(response, start=1)
+                )
+
+    def _list_ingested_files(self) -> list[list[str]]:
+        files = set()
+        for ingested_document in self._ingest_service.list_ingested():
+            if ingested_document.doc_metadata is None:
+                # Skipping documents without metadata
+                continue
+            file_name = ingested_document.doc_metadata.get(
+                "file_name", "[FILE NAME MISSING]"
             )
-    return list(files)
-
-
-# Global state
-_uploaded_file_list = [[row] for row in _list_ingested_files()]
-
-
-def _upload_file(file: TextIO) -> list[list[str]]:
-    path = Path(file.name)
-    ingest_service.ingest(file_name=path.name, file_data=path)
-    _uploaded_file_list.append([path.name])
-    return _uploaded_file_list
-
-
-with gr.Blocks(
-    theme=gr.themes.Soft(primary_hue=slate),
-    css=".logo { "
-    "display:flex;"
-    "background-color: #C7BAFF;"
-    "height: 80px;"
-    "border-radius: 8px;"
-    "align-content: center;"
-    "justify-content: center;"
-    "align-items: center;"
-    "}"
-    ".logo img { height: 25% }",
-) as blocks:
-    with gr.Row():
-        gr.HTML(f"