Commit
Code cleanup and changes (#13)
* A few changes

- Moved file upload to the sidebar
- Added st.cache_resource for API clients
- Reorganized the multi-tenant RAG script

Signed-off-by: Sanket <[email protected]>

* Make responses faster with streaming

Signed-off-by: Sanket <[email protected]>

---------

Signed-off-by: Sanket <[email protected]>
sanketsudake authored Aug 24, 2024
1 parent 038864a commit 749862c
Showing 2 changed files with 63 additions and 53 deletions.
12 changes: 9 additions & 3 deletions app.py
@@ -55,6 +55,7 @@ def authenticate():


# Set up Chroma DB client
@st.cache_resource
def setup_chroma_client():
client = chromadb.HttpClient(
host="http://{host}:{port}".format(
@@ -67,6 +68,7 @@ def setup_chroma_client():


# Set up Chroma embedding function
@st.cache_resource
def hf_embedding_server():
_embedding_function = HuggingFaceEmbeddingServer(
url="http://{host}:{port}/embed".format(
@@ -77,6 +79,7 @@ def hf_embedding_server():


# Set up HuggingFaceEndpoint model
@st.cache_resource
def setup_huggingface_endpoint(model_id):
llm = HuggingFaceEndpoint(
endpoint_url="http://{host}:{port}".format(
@@ -86,6 +89,7 @@ def setup_huggingface_endpoint(model_id):
task="conversational",
stop_sequences=[
"<|im_end|>",
"<|eot_id|>",
"{your_token}".format(
your_token=os.getenv("STOP_TOKEN", "<|end_of_text|>")
),
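This hunk adds "<|eot_id|>", the Llama 3 end-of-turn marker, alongside the ChatML "<|im_end|>" and an environment-overridable token. The "{your_token}".format(...) indirection is equivalent to passing os.getenv through directly, as this sketch shows:

import os

# Fixed chat-template terminators plus one deploy-time override;
# STOP_TOKEN falls back to a default when the variable is unset.
stop_sequences = [
    "<|im_end|>",  # ChatML end-of-turn
    "<|eot_id|>",  # Llama 3 end-of-turn
    os.getenv("STOP_TOKEN", "<|end_of_text|>"),
]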
@@ -97,6 +101,8 @@ def setup_huggingface_endpoint(model_id):
return model


# Set up Portkey integrated model
@st.cache_resource
def setup_portkey_integrated_model():
from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL
from langchain_openai import ChatOpenAI
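The body of setup_portkey_integrated_model is collapsed in this view. The usual Portkey-LangChain wiring points an OpenAI-compatible ChatOpenAI client at the Portkey gateway; a sketch under that assumption, with placeholder environment variable names:

import os

import streamlit as st
from langchain_openai import ChatOpenAI
from portkey_ai import PORTKEY_GATEWAY_URL, createHeaders


@st.cache_resource
def setup_portkey_integrated_model():
    # ChatOpenAI talks to the Portkey gateway instead of the provider's
    # API directly; createHeaders carries the Portkey key and provider.
    return ChatOpenAI(
        api_key=os.getenv("PROVIDER_API_KEY", "placeholder"),
        base_url=PORTKEY_GATEWAY_URL,
        default_headers=createHeaders(
            provider=os.getenv("PORTKEY_PROVIDER", "openai"),
            api_key=os.getenv("PORTKEY_API_KEY"),
        ),
    )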
@@ -118,6 +124,7 @@ def setup_portkey_integrated_model():


# Set up HuggingFaceEndpointEmbeddings embedder
@st.cache_resource
def setup_huggingface_embeddings():
embedder = HuggingFaceEndpointEmbeddings(
model="http://{host}:{port}".format(
@@ -128,6 +135,7 @@ def setup_huggingface_embeddings():
return embedder


@st.cache_resource
def load_prompt_and_system_ins(
template_file_path="templates/prompt_template.tmpl", template=None
):
@@ -235,9 +243,7 @@ def query_docs(
| StrOutputParser()
)

answer = rag_chain.invoke({"question": question, "chat_history": chat_history})
return answer

return rag_chain.stream({"question": question, "chat_history": chat_history})

def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
104 changes: 54 additions & 50 deletions multi_tenant_rag.py
@@ -104,68 +104,72 @@ def main():
with st.chat_message(message["role"]):
st.markdown(message["content"])

if user_id:
if not user_id:
st.error("Please login to continue")
return

collection = client.get_or_create_collection(
f"user-collection-{user_id}", embedding_function=chroma_embeddings
)
collection = client.get_or_create_collection(
f"user-collection-{user_id}", embedding_function=chroma_embeddings
)

uploaded_file = st.file_uploader("Upload a document", type=["pdf"])
uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
question = st.chat_input("Chat with your doc")

rag = MultiTenantRAG(user_id, collection.name, client)
rag = MultiTenantRAG(user_id, collection.name, client)

# prompt = hub.pull("rlm/rag-prompt")
# prompt = hub.pull("rlm/rag-prompt")

vectorstore = Chroma(
embedding_function=embedding_svc,
collection_name=collection.name,
client=client,
)

vectorstore = Chroma(
embedding_function=embedding_svc,
collection_name=collection.name,
client=client,
if uploaded_file:
document = rag.load_documents(uploaded_file)
chunks = rag.chunk_doc(document)
rag.insert_embeddings(
chunks=chunks,
chroma_embedding_function=chroma_embeddings,
# embedder=embedding_svc,
batch_size=32,
)

if uploaded_file:
document = rag.load_documents(uploaded_file)
chunks = rag.chunk_doc(document)
rag.insert_embeddings(
chunks=chunks,
chroma_embedding_function=chroma_embeddings,
# embedder=embedding_svc,
batch_size=32,
if question:
st.chat_message("user").markdown(question)
with st.spinner():
answer = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
chat_history=chat_history,
use_reranker=False,
)
with st.chat_message("assistant"):
answer = st.write_stream(answer)
# print(
# "####\n#### Answer received by querying docs: " + answer + "\n####"
# )

if question := st.chat_input("Chat with your doc"):
st.chat_message("user").markdown(question)
with st.spinner():
answer = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
chat_history=chat_history,
use_reranker=False,
)
# print(
# "####\n#### Answer received by querying docs: " + answer + "\n####"
# )

answer_with_reranker = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
chat_history=chat_history,
use_reranker=True,
)

st.chat_message("assistant").markdown(answer)
st.chat_message("assistant").markdown(answer_with_reranker)

chat_history.append({"role": "user", "content": question})
chat_history.append({"role": "assistant", "content": answer})
st.session_state["chat_history"] = chat_history
# answer_with_reranker = rag.query_docs(
# model=llm,
# question=question,
# vector_store=vectorstore,
# prompt=prompt,
# chat_history=chat_history,
# use_reranker=True,
# )

# st.chat_message("assistant").markdown(answer)
# st.chat_message("assistant").markdown(answer_with_reranker)

chat_history.append({"role": "user", "content": question})
chat_history.append({"role": "assistant", "content": answer})
st.session_state["chat_history"] = chat_history

if __name__ == "__main__":

if __name__ == "__main__":
authenticator = authenticate("login")
if st.session_state["authentication_status"]:
authenticator.logout()
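Taken together, the main() changes replace the original if user_id: nesting with an early return, move the uploader into the sidebar, and render the answer through st.write_stream. A condensed sketch of the resulting control flow; the login lookup and indexing step are assumptions:

import streamlit as st


def main():
    user_id = st.session_state.get("username")  # assumed set at login
    if not user_id:
        st.error("Please login to continue")
        return  # guard clause keeps the happy path un-nested

    uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
    question = st.chat_input("Chat with your doc")

    if uploaded_file:
        pass  # chunk the PDF and insert embeddings before querying

    if question:
        st.chat_message("user").markdown(question)
        with st.chat_message("assistant"):
            answer = st.write_stream(iter(["(streamed ", "answer)"]))
        history = st.session_state.setdefault("chat_history", [])
        history += [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]


if __name__ == "__main__":
    main()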
