diff --git a/mind_palace/app.py b/mind_palace/app.py index 5423d3b..db4265f 100644 --- a/mind_palace/app.py +++ b/mind_palace/app.py @@ -2,6 +2,7 @@ import index import openai import streamlit as st +import welcome from llama_index.query_engine import CitationQueryEngine openai.api_key = st.secrets.openai_key @@ -9,32 +10,33 @@ gpt_model = "gpt-3.5-turbo" st.set_page_config(page_title="Chatting with Steve's PDFs") -st.title("Chat with Steve's 12 PDFs 💬🦙") +st.title("Chat with Steve's PDFs 💬") with st.sidebar: st.markdown("Conversation History") st.text("Coming soon...") -if "messages" not in st.session_state.keys(): # Initialize the chat messages history - st.session_state.messages = [ - {"role": "assistant", "content": "Ask me a question about these PDFs"} - ] - - @st.cache_resource(show_spinner=False) -def load_index(model): +def load_nodes_and_index(xml_dir, model): with st.spinner( text="Loading and indexing the PDFs – hang tight! This should take 1-2 minutes." ): nodes = extract.seed_nodes(xml_dir) vector_index = index.index_nodes(nodes, model) - return vector_index + return nodes, vector_index -vector_index = load_index(gpt_model) +nodes, vector_index = load_nodes_and_index(xml_dir, gpt_model) query_engine = CitationQueryEngine.from_args(index=vector_index, verbose=True) + +if "messages" not in st.session_state.keys(): # Initialize the chat messages history + st.session_state.messages = [ + {"role": "assistant", "content": welcome.get_welcome_message(nodes)} + ] + + if prompt := st.chat_input( "Your question" ): # Prompt for user input and save to chat history diff --git a/mind_palace/docs.py b/mind_palace/docs.py index f942180..c9d9938 100644 --- a/mind_palace/docs.py +++ b/mind_palace/docs.py @@ -1,7 +1,27 @@ +from enum import Enum, auto + import grobid_tei_xml from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode +class Section(Enum): + TITLE = auto() + ABSTRACT = auto() + BODY = auto() + + def __str__(self) -> str: + return self.name.lower() + + +def create_text_node(node_id, text, section: Section, paragraph_number=None): + return TextNode( + text=text, + metadata={"section": str(section), "paragraph_number": paragraph_number}, + excluded_embed_metadata_keys=["section"], + id_=node_id, + ) + + def load_tei_xml(file_path): print(f"Loading {file_path}") with open(file_path, "r") as xml_file: @@ -25,16 +45,14 @@ def cite(xml): def title(xml, doc_id): - return TextNode( - text=xml.header.title, - id_=f"{doc_id}-title", + return create_text_node( + node_id=f"{doc_id}-title", text=xml.header.title, section=Section.TITLE ) def abstract(xml, doc_id): - return TextNode( - text=xml.abstract, - id_=f"{doc_id}-abstract", + return create_text_node( + node_id=f"{doc_id}-abstract", text=xml.abstract, section=Section.ABSTRACT ) @@ -56,12 +74,12 @@ def set_next_relationships(nodes): def body(xml, doc_id): """A naive implementation of body extraction""" - # TODO: Improve body extraction return [ - TextNode( + create_text_node( + node_id=f"{doc_id}-body-paragraph-{index}", text=line, - metadata={"paragraph_number": index + 1}, - id_=f"{doc_id}-body-paragraph-{index}", + section=Section.BODY, + paragraph_number=index + 1, ) for index, line in enumerate(xml.body.split("\n")) ] diff --git a/mind_palace/welcome.py b/mind_palace/welcome.py new file mode 100644 index 0000000..bafe0d6 --- /dev/null +++ b/mind_palace/welcome.py @@ -0,0 +1,2 @@ +def get_welcome_message(nodes): + return "Ask me a question about these PDFs"