From 749862c8a7aa4992c8bb1465521ba14156a9c711 Mon Sep 17 00:00:00 2001 From: Sanket Sudake Date: Sat, 24 Aug 2024 11:28:04 +0530 Subject: [PATCH] Code cleanup and changes (#13) * Few changes - Moved file upload to sidebar - Add st cache for API clients - reorganized multi rag script Signed-off-by: Sanket * Make responses faster with streaming Signed-off-by: Sanket --------- Signed-off-by: Sanket --- app.py | 12 +++-- multi_tenant_rag.py | 104 +++++++++++++++++++++++--------------------- 2 files changed, 63 insertions(+), 53 deletions(-) diff --git a/app.py b/app.py index bf836b5..0bee20b 100644 --- a/app.py +++ b/app.py @@ -55,6 +55,7 @@ def authenticate(): # Set up Chroma DB client +@st.cache_resource def setup_chroma_client(): client = chromadb.HttpClient( host="http://{host}:{port}".format( @@ -67,6 +68,7 @@ def setup_chroma_client(): # Set up Chroma embedding function +@st.cache_resource def hf_embedding_server(): _embedding_function = HuggingFaceEmbeddingServer( url="http://{host}:{port}/embed".format( @@ -77,6 +79,7 @@ def hf_embedding_server(): # Set up HuggingFaceEndpoint model +@st.cache_resource def setup_huggingface_endpoint(model_id): llm = HuggingFaceEndpoint( endpoint_url="http://{host}:{port}".format( @@ -86,6 +89,7 @@ def setup_huggingface_endpoint(model_id): task="conversational", stop_sequences=[ "<|im_end|>", + "<|eot_id|>", "{your_token}".format( your_token=os.getenv("STOP_TOKEN", "<|end_of_text|>") ), @@ -97,6 +101,8 @@ def setup_huggingface_endpoint(model_id): return model +# Set up Portkey integrated model +@st.cache_resource def setup_portkey_integrated_model(): from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL from langchain_openai import ChatOpenAI @@ -118,6 +124,7 @@ def setup_portkey_integrated_model(): # Set up HuggingFaceEndpointEmbeddings embedder +@st.cache_resource def setup_huggingface_embeddings(): embedder = HuggingFaceEndpointEmbeddings( model="http://{host}:{port}".format( @@ -128,6 +135,7 @@ def setup_huggingface_embeddings(): return embedder +@st.cache_resource def load_prompt_and_system_ins( template_file_path="templates/prompt_template.tmpl", template=None ): @@ -235,9 +243,7 @@ def query_docs( | StrOutputParser() ) - answer = rag_chain.invoke({"question": question, "chat_history": chat_history}) - return answer - + return rag_chain.stream({"question": question, "chat_history": chat_history}) def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) diff --git a/multi_tenant_rag.py b/multi_tenant_rag.py index fb4e6d1..7c4bdd2 100644 --- a/multi_tenant_rag.py +++ b/multi_tenant_rag.py @@ -104,68 +104,72 @@ def main(): with st.chat_message(message["role"]): st.markdown(message["content"]) - if user_id: + if not user_id: + st.error("Please login to continue") + return - collection = client.get_or_create_collection( - f"user-collection-{user_id}", embedding_function=chroma_embeddings - ) + collection = client.get_or_create_collection( + f"user-collection-{user_id}", embedding_function=chroma_embeddings + ) - uploaded_file = st.file_uploader("Upload a document", type=["pdf"]) + uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"]) + question = st.chat_input("Chat with your doc") - rag = MultiTenantRAG(user_id, collection.name, client) + rag = MultiTenantRAG(user_id, collection.name, client) - # prompt = hub.pull("rlm/rag-prompt") + # prompt = hub.pull("rlm/rag-prompt") + + vectorstore = Chroma( + embedding_function=embedding_svc, + collection_name=collection.name, + client=client, + ) - vectorstore = 
Chroma( - embedding_function=embedding_svc, - collection_name=collection.name, - client=client, + if uploaded_file: + document = rag.load_documents(uploaded_file) + chunks = rag.chunk_doc(document) + rag.insert_embeddings( + chunks=chunks, + chroma_embedding_function=chroma_embeddings, + # embedder=embedding_svc, + batch_size=32, ) - if uploaded_file: - document = rag.load_documents(uploaded_file) - chunks = rag.chunk_doc(document) - rag.insert_embeddings( - chunks=chunks, - chroma_embedding_function=chroma_embeddings, - # embedder=embedding_svc, - batch_size=32, + if question: + st.chat_message("user").markdown(question) + with st.spinner(): + answer = rag.query_docs( + model=llm, + question=question, + vector_store=vectorstore, + prompt=prompt, + chat_history=chat_history, + use_reranker=False, ) + with st.chat_message("assistant"): + answer = st.write_stream(answer) + # print( + # "####\n#### Answer received by querying docs: " + answer + "\n####" + # ) - if question := st.chat_input("Chat with your doc"): - st.chat_message("user").markdown(question) - with st.spinner(): - answer = rag.query_docs( - model=llm, - question=question, - vector_store=vectorstore, - prompt=prompt, - chat_history=chat_history, - use_reranker=False, - ) - # print( - # "####\n#### Answer received by querying docs: " + answer + "\n####" - # ) - - answer_with_reranker = rag.query_docs( - model=llm, - question=question, - vector_store=vectorstore, - prompt=prompt, - chat_history=chat_history, - use_reranker=True, - ) - - st.chat_message("assistant").markdown(answer) - st.chat_message("assistant").markdown(answer_with_reranker) - - chat_history.append({"role": "user", "content": question}) - chat_history.append({"role": "assistant", "content": answer}) - st.session_state["chat_history"] = chat_history + # answer_with_reranker = rag.query_docs( + # model=llm, + # question=question, + # vector_store=vectorstore, + # prompt=prompt, + # chat_history=chat_history, + # use_reranker=True, + # ) + # st.chat_message("assistant").markdown(answer) + # st.chat_message("assistant").markdown(answer_with_reranker) + + chat_history.append({"role": "user", "content": question}) + chat_history.append({"role": "assistant", "content": answer}) + st.session_state["chat_history"] = chat_history -if __name__ == "__main__": +if __name__ == "__main__": authenticator = authenticate("login") if st.session_state["authentication_status"]: authenticator.logout()
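
Notes on the two patterns this patch introduces (not part of the patch itself): decorating the client/model setup functions with @st.cache_resource makes Streamlit build each API client once and reuse it across reruns instead of reconnecting on every interaction, and switching query_docs from rag_chain.invoke(...) to rag_chain.stream(...) lets the UI render tokens as they arrive via st.write_stream, which also returns the full concatenated answer so it can still be appended to chat_history. The sketch below is a minimal, self-contained illustration of both patterns, not the project's code; it assumes Streamlit >= 1.31 (for st.write_stream), and setup_client / fake_token_stream are hypothetical stand-ins for the real chromadb / HuggingFace clients and for rag_chain.stream(...).

import time

import streamlit as st


@st.cache_resource  # built once per process and reused across reruns/sessions
def setup_client():
    # In the patch this wraps chromadb.HttpClient / HuggingFaceEndpoint setup;
    # a plain dict stands in here so the sketch runs without those services.
    return {"endpoint": "http://localhost:8000"}


def fake_token_stream(question: str):
    # Hypothetical stand-in for rag_chain.stream({...}): yields small chunks
    # the way an LCEL chain piped into StrOutputParser does.
    for word in f"You asked: {question}".split():
        yield word + " "
        time.sleep(0.05)


client = setup_client()

# File upload lives in the sidebar, as in the patched multi_tenant_rag.py.
uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
if uploaded_file:
    st.sidebar.success("Received " + uploaded_file.name)

question = st.chat_input("Chat with your doc")

if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
chat_history = st.session_state["chat_history"]

if question:
    st.chat_message("user").markdown(question)
    with st.chat_message("assistant"):
        # st.write_stream renders chunks as they arrive and returns the
        # concatenated text, so the streamed answer can still be stored.
        answer = st.write_stream(fake_token_stream(question))
    chat_history.append({"role": "user", "content": question})
    chat_history.append({"role": "assistant", "content": answer})
    st.session_state["chat_history"] = chat_history

Running this with "streamlit run sketch.py" shows the assistant message filling in chunk by chunk while chat_history still ends up holding the complete answer text, which is the behaviour the patched multi_tenant_rag.py relies on when it appends the result of st.write_stream.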