diff --git a/mind_palace/app.py b/mind_palace/app.py index 0d01673..5261cae 100644 --- a/mind_palace/app.py +++ b/mind_palace/app.py @@ -3,11 +3,15 @@ import openai import streamlit as st import welcome +from itune import MultiArmedBandit, Tune from llama_index.query_engine import CitationQueryEngine openai.api_key = st.secrets.openai_key xml_dir = "./resources/xmls/dennis-oct-10/" gpt_model = "gpt-3.5-turbo" +itune = Tune(strategy=MultiArmedBandit()) + +itune.load() st.set_page_config(page_title="Q&A with Dennis's PDFs") st.title("Q&A with Dennis's PDFs 💬") @@ -28,7 +32,11 @@ def load_nodes_and_index(xml_dir, model): nodes, vector_index = load_nodes_and_index(xml_dir, gpt_model) -query_engine = CitationQueryEngine.from_args(index=vector_index, verbose=True) +query_engine = CitationQueryEngine.from_args( + index=vector_index, + similarity_top_k=itune.choose(similarity_top_k=[3, 5]), + verbose=True, +) @st.cache_data( @@ -54,21 +62,34 @@ def get_welcome_message(abstracts): "content": "Ask me a question about these papers.", }, ] -else: +# if this refresh is not triggered by user button press, reset the chat messages +elif not st.session_state.rating_button_pressed: + print("clearing chat messages!") st.session_state.messages = [] +# this needs to be reset after checking if the messaages should be cleared +st.session_state.rating_button_pressed = False if prompt := st.chat_input( "Your question" ): # Prompt for user input and save to chat history st.session_state.messages.append({"role": "user", "content": prompt}) + for message in st.session_state.messages: # Display the prior chat messages with st.chat_message(message["role"]): st.write(message["content"]) -# If last message is not from assistant, generate a new response -if st.session_state.messages[-1]["role"] != "assistant": + +def user_clicked_rating(is_good_response): + print(f"user thumbs {'up' if is_good_response else 'down'}") + itune.register_outcome(is_good_response) + itune.save() + st.session_state.rating_button_pressed = True + + +# If last message is from user, generate a new response +if st.session_state.messages[-1]["role"] == "user": with st.chat_message("assistant"): with st.spinner("Thinking..."): response = query_engine.query(prompt) @@ -77,6 +98,24 @@ def get_welcome_message(abstracts): message = {"role": "assistant", "content": response.response} st.session_state.messages.append(message) # Add response to message history + _, col1, col2 = st.columns([7, 1, 1], gap="small") + col1.button( + "👍", + on_click=user_clicked_rating, + args=[True], + help="Good response", + key="good_response", + use_container_width=True, + ) + col2.button( + "👎", + on_click=user_clicked_rating, + args=[False], + help="Bad response", + key="bad_response", + use_container_width=True, + ) + st.markdown("### Sources") for i, source_node in enumerate(response.source_nodes): with st.expander(f"[{i + 1}] {source_node.node.metadata['citation']}"): diff --git a/requirements/build.txt b/requirements/build.txt index ce83e11..ca7bc9e 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -1,3 +1,4 @@ +itune==0.1.1 pypdf==3.16.0 grobid-tei-xml==0.1.3 nltk==3.8.1