05_RAG_Chatbot_v3.py
import json
import os
from typing import Any, Dict, List

import streamlit as st
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter

class PDFQuestionAnsweringSystem:
    def __init__(self) -> None:
        """
        Initialize the PDFQuestionAnsweringSystem with a fixed storage path for the Chroma vector database.
        """
        self.vector_db = None
        self.embeddings_data = None
        self.llm_model = "mistral"
        self.embedding_model = "nomic-embed-text"
        self.collection_name = "local-rag"
        self.storage_path = "chroma_vector_storage"  # Fixed directory for storing the vector database
        self.retriever = None
        self.history = []  # To keep track of question and answer history
        # Ensure the storage path exists
        os.makedirs(self.storage_path, exist_ok=True)
    def load_pdf(self, file_path: str):
        """
        Load a PDF from the specified file path.

        Args:
            file_path (str): Path to the PDF file.
        """
        loader = PyPDFLoader(file_path=file_path)
        return loader.load()

    def split_text(self, data):
        """
        Split the text data into smaller chunks for processing.

        Args:
            data: Document data from the PDF.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        return text_splitter.split_documents(data)
    def add_to_vector_store(self, documents) -> None:
        """
        Add the document chunks to the vector store with the specified storage path.

        Args:
            documents: List of document chunks.
        """
        # Initialize the Chroma vector store with a persistent storage directory
        self.vector_db = Chroma.from_documents(
            documents=documents,
            embedding=OllamaEmbeddings(model=self.embedding_model, show_progress=True),
            collection_name=self.collection_name,
            persist_directory=self.storage_path,  # Directory for persistent storage
        )

    def generate_and_store_embeddings(self, file_path: str) -> None:
        """
        Generate embeddings from the PDF and store them in the vector store.

        Args:
            file_path (str): Path to the PDF file.
        """
        data = self.load_pdf(file_path)
        chunks = self.split_text(data)
        # Keep a copy of (chunk, embedding) pairs so they can be exported as JSON later;
        # note that Chroma.from_documents embeds the chunks again independently.
        embeddings = OllamaEmbeddings(model=self.embedding_model, show_progress=True)
        self.embeddings_data = [
            (chunk, embeddings.embed_documents([chunk.page_content])[0]) for chunk in chunks
        ]
        self.add_to_vector_store(chunks)
    def get_retriever(self) -> Any:
        """
        Get a retriever for querying the vector database.

        Returns:
            Any: The retriever object.
        """
        if self.vector_db:
            return self.vector_db.as_retriever()
        return None

    def delete_collection(self) -> None:
        """
        Delete the existing collection from the vector store.
        """
        if self.vector_db:
            self.vector_db.delete_collection()
            self.retriever = None  # Clear the retriever when the collection is deleted

    def set_retriever(self, retriever: Any) -> None:
        """
        Set the retriever object.

        Args:
            retriever (Any): The retriever object.
        """
        self.retriever = retriever
    def query_llm(self, input_data: Dict[str, Any]) -> str:
        """
        Query the language model to get an answer based on input data.

        Args:
            input_data (Dict[str, Any]): The input data for the query; expects a "question" key.

        Returns:
            str: The answer from the language model.
        """
        llm = ChatOllama(model=self.llm_model)
        template = """Answer the question based ONLY on the following context:
        {context}
        Question: {question}
        """
        prompt = ChatPromptTemplate.from_template(template)
        chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        # The retriever and RunnablePassthrough both expect the raw question string,
        # so invoke the chain with the question itself rather than the whole dict.
        return chain.invoke(input_data["question"])
    def upload_and_process_pdf(self, file_path: str) -> None:
        """
        Upload and process the PDF, generating and storing embeddings.

        Args:
            file_path (str): Path to the uploaded PDF file.
        """
        self.generate_and_store_embeddings(file_path)
        retriever = self.get_retriever()
        if retriever:
            llm = ChatOllama(model=self.llm_model)
            self.set_retriever(
                MultiQueryRetriever.from_llm(retriever, llm, self._query_prompt_template())
            )

    def _query_prompt_template(self) -> PromptTemplate:
        """
        Define the prompt template for the multi-query retriever.

        Returns:
            PromptTemplate: The prompt template object.
        """
        return PromptTemplate(
            input_variables=["question"],
            template="""You are an AI language model assistant. Your task is to generate five
            different versions of the given user question to retrieve relevant documents from
            a vector database. By generating multiple perspectives on the user question, your
            goal is to help the user overcome some of the limitations of the distance-based
            similarity search. Provide these alternative questions separated by newlines.
            Original question: {question}""",
        )
    def get_embeddings_data(self) -> List[Dict[str, Any]]:
        """
        Get the stored embeddings data.

        Returns:
            List[Dict[str, Any]]: A list of embeddings data with text.
        """
        return [
            {"text": chunk.page_content, "embedding": embedding}
            for chunk, embedding in self.embeddings_data
        ]

    def get_answer(self, question: str) -> str:
        """
        Get an answer from the language model based on the question.

        Args:
            question (str): The user's question.

        Returns:
            str: The answer from the language model.
        """
        answer = self.query_llm({"question": question})
        self.history.append({"question": question, "answer": answer})  # Add to history
        return answer

    def get_history(self) -> List[Dict[str, str]]:
        """
        Get the question and answer history.

        Returns:
            List[Dict[str, str]]: List of questions and answers.
        """
        return self.history
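
# A minimal non-Streamlit usage sketch of the class above. It assumes a local Ollama
# server is running with the "mistral" and "nomic-embed-text" models pulled, and that
# a PDF named "example.pdf" exists (the filename is purely illustrative):
#
#   qa = PDFQuestionAnsweringSystem()
#   qa.upload_and_process_pdf("example.pdf")
#   print(qa.get_answer("What is this document about?"))
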
# Streamlit
def main() -> None:
    st.title("PDF Question Answering System")
    st.write("Upload a PDF and ask questions based on its content.")
    # Keep the system in session state so the retriever and the Q&A history
    # survive Streamlit reruns instead of being rebuilt on every interaction
    if "qa_system" not in st.session_state:
        st.session_state["qa_system"] = PDFQuestionAnsweringSystem()
    system = st.session_state["qa_system"]
    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
    if uploaded_file is not None:
        # Save the uploaded PDF and process it only once per uploaded file
        if st.session_state.get("processed_file") != uploaded_file.name:
            with open("uploaded_file.pdf", "wb") as f:
                f.write(uploaded_file.getbuffer())
            system.upload_and_process_pdf("uploaded_file.pdf")
            st.session_state["processed_file"] = uploaded_file.name
        st.write("Embeddings generated and stored successfully.")
        # Option to download the embeddings data
        embeddings_data = system.get_embeddings_data()
        embeddings_json = json.dumps(embeddings_data, indent=4)
        st.download_button("Download Embeddings Data", embeddings_json, "embeddings.json")
        question = st.text_input("Ask a question based on the uploaded PDF", key="question_input")
        if st.button("Get Answer"):
            if question:
                answer = system.get_answer(question)
                st.write(answer)
            # Streamlit does not allow modifying a widget's session_state value after the
            # widget is instantiated, so the text input is not cleared programmatically here.
        # Display history
        st.write("## Question and Answer History")
        for entry in system.get_history():
            st.write(f"**Question:** {entry['question']}")
            st.write(f"**Answer:** {entry['answer']}")
            st.write("---")
        if st.button("Clear Database"):
            system.delete_collection()
            st.write("Vector database cleared.")


if __name__ == "__main__":
    main()
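
# To try the app, assuming the LangChain/Streamlit dependencies are installed and an
# Ollama server is running locally with the "mistral" and "nomic-embed-text" models
# available, a typical invocation is:
#
#   streamlit run 05_RAG_Chatbot_v3.py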