Skip to content

Commit

Permalink
Merge pull request #1 from bilge-ince/aidb-rag
Browse files Browse the repository at this point in the history
Implement aidb into RAG application
  • Loading branch information
bilge-ince authored Jul 31, 2024
2 parents b1cbbff + 69f415c commit fc63991
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 165 deletions.
5 changes: 3 additions & 2 deletions .env-example
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# DATABASE
DB_NAME=vector_test
DB_USER=postgres
DB_PASSWORD=postgres
DB_PASSWORD=password
DB_HOST=localhost
DB_PORT=5432
DB_PORT=15432

# MODEL
AIDB_MODEL_NAME=all-MiniLM-L6-v2
MODEL_NAME=mistralai/Mistral-7B-Instruct-v0.2
TOKENIZER_NAME=mistralai/Mistral-7B-Instruct-v0.2
HUGGING_FACE_ACCESS_TOKEN=
51 changes: 14 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
# pgvector-rag
An application that demonstrates how you can build a RAG (retrieval-augmented generation) application using pgvector and PostgreSQL
# aidb-rag
An application that demonstrates how you can build a RAG (retrieval-augmented generation) application using EDB's aidb and PostgreSQL.

![Sample Chat Console Output](/imgs/chat%20console.png)

## Requirements
- Python3
- PostgreSQL
- pgvector
- aidb

## Install

Clone the repository

```
git clone git@github.com:gulcin/pgvector-rag.git
cd pgvector-rag
git clone git@github.com:gulcin/aidb-rag-app.git
cd aidb-rag-app
```

Install Dependencies
Expand All @@ -31,10 +33,16 @@ cp .env-example .env

## Run

First, install and start the `aidb` extension by following the step-by-step installation guide: https://www.enterprisedb.com/docs/edb-postgres-ai/ai-ml/install-tech-preview/

Make sure your aidb extension is ready to accept connections. Then you can continue as follows:

```
python app.py --help
usage: app.py [-h] {create-db,import-data,chat} ...
usage: app.py [-h] {create-db,import-data,chat} {data_source}
e.g: python app.py import-data sample.pdf
Application Description
Expand All @@ -49,34 +57,3 @@ Subcommands:
chat Use chat feature
```

## Run UI

We use Streamlit for creating a simple Graphical User Interface for our pgvector-rag app.

To be able to run Streamlit please do the following:

```
pip install streamlit
```

**Add keys/secrets to Streamlit secrets**

If you need to store secrets that the Streamlit app will use, create a
`.streamlit/secrets.toml` file under the Streamlit directory and add lines like the following:

```
# .streamlit/secrets.toml
OPENAI_API_KEY = "YOUR_API_KEY"
```
**Run Streamlit app for generating UI**

```
streamlit run chatgptui.py
```
You can create as many apps as you'd like, place them under the Streamlit directory,
edit the keys if needed, and run them as described above.





58 changes: 33 additions & 25 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,40 @@ def main():

args = parser.parse_args()

if hasattr(args, "func"):
if torch.cuda.is_available():
device = "cuda"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
if args.command==Command.CHAT.value:
if hasattr(args, "func"):
if torch.cuda.is_available():
device = "cuda"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
dtype = torch.float16
elif torch.backends.mps.is_available():
device = "mps"
bnb_config = None
dtype = torch.float16 # MPS supports float16å

else:
device = "cpu"
bnb_config = None
tokenizer = AutoTokenizer.from_pretrained(
os.getenv("TOKENIZER_NAME"),
token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
)
else:
device = "cpu"
bnb_config = None

tokenizer = AutoTokenizer.from_pretrained(
os.getenv("TOKENIZER_NAME"),
token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
)
model = AutoModelForCausalLM.from_pretrained(
os.getenv("MODEL_NAME"),
token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
quantization_config=bnb_config,
device_map=device,
torch_dtype=torch.float16,
)

args.func(args, model, device, tokenizer)
model = AutoModelForCausalLM.from_pretrained(
os.getenv("MODEL_NAME"),
token=os.getenv("HUGGING_FACE_ACCESS_TOKEN"),
quantization_config=bnb_config,
device_map=device,
torch_dtype=torch.float16,
)

args.func(args, model, device, tokenizer)
elif ((args.command==Command.IMPORT_DATA.value) or (args.command==Command.CREATE_DB.value)):
args.func(args)
else:
print("Invalid command. Use '--help' for assistance.")

Expand Down
2 changes: 1 addition & 1 deletion commands/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def chat(args, model, device, tokenizer):
if question.lower() == "exit":
break

answer = rag_query(tokenizer=tokenizer, model=model, device=device, query=question)
answer = rag_query(tokenizer=tokenizer, model=model, device=device, query=question, topk=5)

print(f"You Asked: {question}")
print(f"Answer: {answer}")
Expand Down
6 changes: 3 additions & 3 deletions commands/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import psycopg2


def create_db(args, model, device, tokenizer):
def create_db(args):
db_config = {
"user": os.getenv("DB_USER"),
"password": os.getenv("DB_PASSWORD"),
Expand Down Expand Up @@ -32,12 +32,12 @@ def create_db(args, model, device, tokenizer):
conn.autocommit = True

cursor = conn.cursor()
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
cursor.execute("CREATE EXTENSION IF NOT EXISTS aidb cascade;")
cursor.close()

cursor = conn.cursor()
cursor.execute(
"CREATE TABLE IF NOT EXISTS embeddings (id serial PRIMARY KEY, doc_fragment text, embeddings vector(4096));"
"CREATE TABLE IF NOT EXISTS documents (id text PRIMARY KEY, doc_fragment text);"
)
cursor.close()

Expand Down
20 changes: 6 additions & 14 deletions commands/import_data.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
import numpy as np

from db import get_connection
from embedding import generate_embeddings, read_pdf_file


def import_data(args, model, device, tokenizer):
def import_data(args):
data = read_pdf_file(args.data_source)

embeddings = [
generate_embeddings(tokenizer=tokenizer, model=model, device=device, text=line)
for line in data
]

conn = get_connection()
cursor = conn.cursor()

# Store each embedding in the database
for i, (doc_fragment, embedding) in enumerate(embeddings):
for i, (doc_fragment) in enumerate(data):
cursor.execute(
"INSERT INTO embeddings (id, doc_fragment, embeddings) VALUES (%s, %s, %s)",
(i, doc_fragment, embedding[0]),
"INSERT INTO documents (id, doc_fragment) VALUES (%s, %s)",
(i, doc_fragment),
)
conn.commit()

generate_embeddings()
print(
"import-data command executed. Data source: {}".format(
args.data_source
)
)

30 changes: 20 additions & 10 deletions embedding.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
# importing all the required modules
import os
import PyPDF2
import torch
from transformers import pipeline
from db import get_connection

def generate_embeddings(tokenizer, model, device, text):
inputs = tokenizer(
text, return_tensors="pt", truncation=True, max_length=512
).to(device)
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
return text, outputs.hidden_states[-1].mean(dim=1).tolist()
def generate_embeddings():
    """Create the ``documents_embeddings`` aidb retriever and refresh it.

    Issues two statements on a connection obtained from ``get_connection()``:
    ``aidb.create_pg_retriever(...)`` over the ``documents`` table, followed
    by ``aidb.refresh_retriever('documents_embeddings')``, then commits.

    The model name is read from the ``AIDB_MODEL_NAME`` environment variable
    and passed to the database as a bound parameter rather than being
    interpolated into the SQL text.

    Returns:
        None.

    Raises:
        ValueError: If ``AIDB_MODEL_NAME`` is unset or empty.
    """
    model_name = os.getenv("AIDB_MODEL_NAME")
    if not model_name:
        # Without this check the literal string 'None' would be sent to aidb
        # as the model name and fail far from the real cause.
        raise ValueError("AIDB_MODEL_NAME environment variable is not set")

    conn = get_connection()
    cursor = conn.cursor()
    try:
        # NOTE(review): the meaning of each positional argument is defined by
        # aidb.create_pg_retriever; the values are kept exactly as before,
        # only the model name is now a bound parameter.
        cursor.execute(
            """
            SELECT aidb.create_pg_retriever(
                'documents_embeddings',
                'public',
                'id',
                %s,
                'text',
                'documents',
                ARRAY['id', 'doc_fragment'],
                FALSE);""",
            (model_name,),
        )
        cursor.execute(
            """
            SELECT aidb.refresh_retriever('documents_embeddings');"""
        )
        conn.commit()
    finally:
        # Release the cursor even when a statement fails. The connection is
        # left open because its ownership is not visible from this block.
        cursor.close()
    return None


def read_pdf_file(pdf_path):
Expand Down
Binary file added imgs/chat console.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 13 additions & 28 deletions rag.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
from itertools import chain
import torch
from pgvector.psycopg2 import register_vector

from db import get_connection
from embedding import generate_embeddings

from pgvector.psycopg2 import register_vector

template = """<s>[INST]
You are a friendly documentation search bot.
Expand All @@ -20,38 +13,30 @@
Answer:
"""

def get_retrieval_condition(query_embedding, threshold=0.7):
def get_retrieval_condition(query_embedding, topk):
# Convert query embedding to a string format for SQL query
query_embedding_str = ",".join(map(str, query_embedding))

# SQL condition for cosine similarity
condition = f"(embeddings <=> '{query_embedding_str}') < {threshold} ORDER BY embeddings <=> '{query_embedding_str}'"
return condition


def rag_query(tokenizer, model, device, query):
# Generate query embedding
query_embedding = generate_embeddings(
tokenizer=tokenizer, model=model, device=device, text=query
)[1]

# Retrieve relevant embeddings from the database
retrieval_condition = get_retrieval_condition(query_embedding)

# # SQL condition for cosine similarity
# condition = f"(embeddings <=> '{query_embedding_str}') < {threshold} ORDER BY embeddings <=> '{query_embedding_str}'"
conn = get_connection()
register_vector(conn)
cursor = conn.cursor()
cursor.execute(
f"SELECT doc_fragment FROM embeddings WHERE {retrieval_condition} LIMIT 5"
)
retrieved = cursor.fetchall()
f"""SELECT data from aidb.retrieve('{query_embedding_str}', {topk}, 'documents_embeddings');"""
)
results = cursor.fetchall()
rag_query = ' '.join([row[0] for row in results])
return rag_query

rag_query = ' '.join([row[0] for row in retrieved])

def rag_query(tokenizer, model, device, query, topk):
# Retrieve relevant embeddings from the database
rag_query = get_retrieval_condition(query, topk)
query_template = template.format(context=rag_query, question=query)

input_ids = tokenizer.encode(query_template, return_tensors="pt")

# Generate the response
generated_response = model.generate(input_ids.to(device), max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
model.generation_config.pad_token_id = tokenizer.pad_token_id
generated_response = model.generate(input_ids.to(device), max_new_tokens=100)
return tokenizer.decode(generated_response[0][input_ids.shape[-1]:], skip_special_tokens=True)
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ psycopg2
transformers
torch
black
pgvector
PyPDF2
bitsandbytes
accelerate
Binary file added sample_k_means.pdf
Binary file not shown.
44 changes: 0 additions & 44 deletions streamlit/chatgptui.py

This file was deleted.

Empty file removed streamlit/llamaindex.py
Empty file.

0 comments on commit fc63991

Please sign in to comment.